diff --git a/MAINTAINERS.md b/MAINTAINERS.md
index 093c82b3afe8..c6672c0a3efe 100644
--- a/MAINTAINERS.md
+++ b/MAINTAINERS.md
@@ -8,17 +8,18 @@ See [CONTRIBUTING.md](https://github.com/grpc/grpc-community/blob/master/CONTRIBUTING.md)
 for general contribution guidelines.
 
 ## Maintainers (in alphabetical order)
-- [canguler](https://github.com/canguler), Google LLC
+
 - [cesarghali](https://github.com/cesarghali), Google LLC
 - [dfawley](https://github.com/dfawley), Google LLC
 - [easwars](https://github.com/easwars), Google LLC
-- [jadekler](https://github.com/jadekler), Google LLC
 - [menghanl](https://github.com/menghanl), Google LLC
 - [srini100](https://github.com/srini100), Google LLC
 
 ## Emeritus Maintainers (in alphabetical order)
 - [adelez](https://github.com/adelez), Google LLC
+- [canguler](https://github.com/canguler), Google LLC
 - [iamqizhao](https://github.com/iamqizhao), Google LLC
+- [jadekler](https://github.com/jadekler), Google LLC
 - [jtattermusch](https://github.com/jtattermusch), Google LLC
 - [lyuxuan](https://github.com/lyuxuan), Google LLC
 - [makmukhi](https://github.com/makmukhi), Google LLC
diff --git a/README.md b/README.md
index 3949a683fb58..0e6ae69a5846 100644
--- a/README.md
+++ b/README.md
@@ -136,6 +136,6 @@ errors.
 [Go module]: https://github.com/golang/go/wiki/Modules
 [gRPC]: https://grpc.io
 [Go gRPC docs]: https://grpc.io/docs/languages/go
-[Performance benchmark]: https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5652536396611584&widget=490377658&container=1286539696
+[Performance benchmark]: https://performance-dot-grpc-testing.appspot.com/explore?dashboard=5180705743044608
 [quick start]: https://grpc.io/docs/languages/go/quickstart
 [go-releases]: https://golang.org/doc/devel/release.html
diff --git a/authz/rbac_translator.go b/authz/rbac_translator.go
new file mode 100644
index 000000000000..8dc764896053
--- /dev/null
+++ b/authz/rbac_translator.go
@@ -0,0 +1,301 @@
+/*
+ * Copyright 2021 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Package authz exposes methods to manage authorization within gRPC.
+//
+// Experimental
+//
+// Notice: This package is EXPERIMENTAL and may be changed or removed
+// in a later release.
+package authz
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+
+	v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
+	v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
+	v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3"
+)
+
+type header struct {
+	Key    string
+	Values []string
+}
+
+type peer struct {
+	Principals []string
+}
+
+type request struct {
+	Paths   []string
+	Headers []header
+}
+
+type rule struct {
+	Name    string
+	Source  peer
+	Request request
+}
+
+// authorizationPolicy represents the SDK authorization policy provided by the
+// user.
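+//
+// A hypothetical policy, matching the JSON field names declared on the
+// struct below (the concrete values are illustrative only):
+//
+//	{
+//	  "name": "example-authz",
+//	  "deny_rules":  [{"name": "...", "source": {"principals": ["..."]}}],
+//	  "allow_rules": [{"name": "...", "request": {"paths": ["..."]}}]
+//	}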
+type authorizationPolicy struct {
+	Name       string
+	DenyRules  []rule `json:"deny_rules"`
+	AllowRules []rule `json:"allow_rules"`
+}
+
+func principalOr(principals []*v3rbacpb.Principal) *v3rbacpb.Principal {
+	return &v3rbacpb.Principal{
+		Identifier: &v3rbacpb.Principal_OrIds{
+			OrIds: &v3rbacpb.Principal_Set{
+				Ids: principals,
+			},
+		},
+	}
+}
+
+func permissionOr(permission []*v3rbacpb.Permission) *v3rbacpb.Permission {
+	return &v3rbacpb.Permission{
+		Rule: &v3rbacpb.Permission_OrRules{
+			OrRules: &v3rbacpb.Permission_Set{
+				Rules: permission,
+			},
+		},
+	}
+}
+
+func permissionAnd(permission []*v3rbacpb.Permission) *v3rbacpb.Permission {
+	return &v3rbacpb.Permission{
+		Rule: &v3rbacpb.Permission_AndRules{
+			AndRules: &v3rbacpb.Permission_Set{
+				Rules: permission,
+			},
+		},
+	}
+}
+
+func getStringMatcher(value string) *v3matcherpb.StringMatcher {
+	switch {
+	case value == "*":
+		return &v3matcherpb.StringMatcher{
+			MatchPattern: &v3matcherpb.StringMatcher_Prefix{},
+		}
+	case strings.HasSuffix(value, "*"):
+		prefix := strings.TrimSuffix(value, "*")
+		return &v3matcherpb.StringMatcher{
+			MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: prefix},
+		}
+	case strings.HasPrefix(value, "*"):
+		suffix := strings.TrimPrefix(value, "*")
+		return &v3matcherpb.StringMatcher{
+			MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: suffix},
+		}
+	default:
+		return &v3matcherpb.StringMatcher{
+			MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: value},
+		}
+	}
+}
+
+func getHeaderMatcher(key, value string) *v3routepb.HeaderMatcher {
+	switch {
+	case value == "*":
+		return &v3routepb.HeaderMatcher{
+			Name:                 key,
+			HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{},
+		}
+	case strings.HasSuffix(value, "*"):
+		prefix := strings.TrimSuffix(value, "*")
+		return &v3routepb.HeaderMatcher{
+			Name:                 key,
+			HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: prefix},
+		}
+	case strings.HasPrefix(value, "*"):
+		suffix := strings.TrimPrefix(value, "*")
+		return &v3routepb.HeaderMatcher{
+			Name:                 key,
+			HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: suffix},
+		}
+	default:
+		return &v3routepb.HeaderMatcher{
+			Name:                 key,
+			HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: value},
+		}
+	}
+}
+
+func parsePrincipalNames(principalNames []string) []*v3rbacpb.Principal {
+	var ps []*v3rbacpb.Principal
+	for _, principalName := range principalNames {
+		newPrincipalName := &v3rbacpb.Principal{
+			Identifier: &v3rbacpb.Principal_Authenticated_{
+				Authenticated: &v3rbacpb.Principal_Authenticated{
+					PrincipalName: getStringMatcher(principalName),
+				},
+			}}
+		ps = append(ps, newPrincipalName)
+	}
+	return ps
+}
+
+func parsePeer(source peer) (*v3rbacpb.Principal, error) {
+	if len(source.Principals) > 0 {
+		return principalOr(parsePrincipalNames(source.Principals)), nil
+	}
+	return &v3rbacpb.Principal{
+		Identifier: &v3rbacpb.Principal_Any{
+			Any: true,
+		},
+	}, nil
+}
+
+func parsePaths(paths []string) []*v3rbacpb.Permission {
+	var ps []*v3rbacpb.Permission
+	for _, path := range paths {
+		newPath := &v3rbacpb.Permission{
+			Rule: &v3rbacpb.Permission_UrlPath{
+				UrlPath: &v3matcherpb.PathMatcher{
+					Rule: &v3matcherpb.PathMatcher_Path{Path: getStringMatcher(path)}}}}
+		ps = append(ps, newPath)
+	}
+	return ps
+}
+
+func parseHeaderValues(key string, values []string) []*v3rbacpb.Permission {
+	var vs []*v3rbacpb.Permission
+	for _, value := range values {
+		newHeader := &v3rbacpb.Permission{
+			Rule: &v3rbacpb.Permission_Header{
+				Header: getHeaderMatcher(key, value)}}
+		vs = append(vs, newHeader)
+	}
+	return vs
+}
+
+var unsupportedHeaders = map[string]bool{
+	"host":                true,
+	"connection":          true,
+	"keep-alive":          true,
+	"proxy-authenticate":  true,
+	"proxy-authorization": true,
+	"te":                  true,
+	"trailer":             true,
+	"transfer-encoding":   true,
+	"upgrade":             true,
+}
+
+func unsupportedHeader(key string) bool {
+	return key[0] == ':' || strings.HasPrefix(key, "grpc-") || unsupportedHeaders[key]
+}
+
+func parseHeaders(headers []header) ([]*v3rbacpb.Permission, error) {
+	var hs []*v3rbacpb.Permission
+	for i, header := range headers {
+		if header.Key == "" {
+			return nil, fmt.Errorf(`"headers" %d: "key" is not present`, i)
+		}
+		header.Key = strings.ToLower(header.Key)
+		if unsupportedHeader(header.Key) {
+			return nil, fmt.Errorf(`"headers" %d: unsupported "key" %s`, i, header.Key)
+		}
+		if len(header.Values) == 0 {
+			return nil, fmt.Errorf(`"headers" %d: "values" is not present`, i)
+		}
+		values := parseHeaderValues(header.Key, header.Values)
+		hs = append(hs, permissionOr(values))
+	}
+	return hs, nil
+}
+
+func parseRequest(request request) (*v3rbacpb.Permission, error) {
+	var and []*v3rbacpb.Permission
+	if len(request.Paths) > 0 {
+		and = append(and, permissionOr(parsePaths(request.Paths)))
+	}
+	if len(request.Headers) > 0 {
+		headers, err := parseHeaders(request.Headers)
+		if err != nil {
+			return nil, err
+		}
+		and = append(and, permissionAnd(headers))
+	}
+	if len(and) > 0 {
+		return permissionAnd(and), nil
+	}
+	return &v3rbacpb.Permission{
+		Rule: &v3rbacpb.Permission_Any{
+			Any: true,
+		},
+	}, nil
+}
+
+func parseRules(rules []rule, prefixName string) (map[string]*v3rbacpb.Policy, error) {
+	policies := make(map[string]*v3rbacpb.Policy)
+	for i, rule := range rules {
+		if rule.Name == "" {
+			return policies, fmt.Errorf(`%d: "name" is not present`, i)
+		}
+		principal, err := parsePeer(rule.Source)
+		if err != nil {
+			return nil, fmt.Errorf("%d: %v", i, err)
+		}
+		permission, err := parseRequest(rule.Request)
+		if err != nil {
+			return nil, fmt.Errorf("%d: %v", i, err)
+		}
+		policyName := prefixName + "_" + rule.Name
+		policies[policyName] = &v3rbacpb.Policy{
+			Principals:  []*v3rbacpb.Principal{principal},
+			Permissions: []*v3rbacpb.Permission{permission},
+		}
+	}
+	return policies, nil
+}
+
+// translatePolicy translates an SDK authorization policy in JSON format into
+// two Envoy RBAC policies (a deny policy and an allow policy). If the policy
+// cannot be parsed or is invalid, an error will be returned.
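+//
+// For example (an illustrative policy, not taken from this change), a policy
+// named "authz" with one allow rule named "allow_all" yields a nil deny RBAC
+// and an ALLOW RBAC whose policy map has the single key "authz_allow_all",
+// because parseRules prefixes every rule name with the policy name:
+//
+//	denyRBAC, allowRBAC, err := translatePolicy(`{
+//	    "name": "authz",
+//	    "allow_rules": [{"name": "allow_all"}]
+//	}`)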
+func translatePolicy(policyStr string) (*v3rbacpb.RBAC, *v3rbacpb.RBAC, error) {
+	var policy authorizationPolicy
+	if err := json.Unmarshal([]byte(policyStr), &policy); err != nil {
+		return nil, nil, fmt.Errorf("failed to unmarshal policy: %v", err)
+	}
+	if policy.Name == "" {
+		return nil, nil, fmt.Errorf(`"name" is not present`)
+	}
+	if len(policy.AllowRules) == 0 {
+		return nil, nil, fmt.Errorf(`"allow_rules" is not present`)
+	}
+	allowPolicies, err := parseRules(policy.AllowRules, policy.Name)
+	if err != nil {
+		return nil, nil, fmt.Errorf(`"allow_rules" %v`, err)
+	}
+	allowRBAC := &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: allowPolicies}
+	var denyRBAC *v3rbacpb.RBAC
+	if len(policy.DenyRules) > 0 {
+		denyPolicies, err := parseRules(policy.DenyRules, policy.Name)
+		if err != nil {
+			return nil, nil, fmt.Errorf(`"deny_rules" %v`, err)
+		}
+		denyRBAC = &v3rbacpb.RBAC{
+			Action:   v3rbacpb.RBAC_DENY,
+			Policies: denyPolicies,
+		}
+	}
+	return denyRBAC, allowRBAC, nil
+}
diff --git a/authz/rbac_translator_test.go b/authz/rbac_translator_test.go
new file mode 100644
index 000000000000..425cae85b037
--- /dev/null
+++ b/authz/rbac_translator_test.go
@@ -0,0 +1,228 @@
+/*
+ *
+ * Copyright 2021 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package authz
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"google.golang.org/protobuf/testing/protocmp"
+
+	v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3"
+	v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
+	v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3"
+)
+
+func TestTranslatePolicy(t *testing.T) {
+	tests := map[string]struct {
+		authzPolicy     string
+		wantErr         string
+		wantDenyPolicy  *v3rbacpb.RBAC
+		wantAllowPolicy *v3rbacpb.RBAC
+	}{
+		"valid policy": {
+			authzPolicy: `{
+				"name": "authz",
+				"deny_rules": [
+				{
+					"name": "deny_policy_1",
+					"source": {
+						"principals":[
+						"spiffe://foo.abc",
+						"spiffe://bar*",
+						"*baz",
+						"spiffe://abc.*.com"
+						]
+					}
+				}],
+				"allow_rules": [
+				{
+					"name": "allow_policy_1",
+					"source": {
+						"principals":["*"]
+					},
+					"request": {
+						"paths": ["path-foo*"]
+					}
+				},
+				{
+					"name": "allow_policy_2",
+					"request": {
+						"paths": [
+						"path-bar",
+						"*baz"
+						],
+						"headers": [
+						{
+							"key": "key-1",
+							"values": ["foo", "*bar"]
+						},
+						{
+							"key": "key-2",
+							"values": ["baz*"]
+						}
+						]
+					}
+				}]
+			}`,
+			wantDenyPolicy: &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_DENY, Policies: map[string]*v3rbacpb.Policy{
+				"authz_deny_policy_1": {
+					Principals: []*v3rbacpb.Principal{
+						{Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{
+							Ids: []*v3rbacpb.Principal{
+								{Identifier: &v3rbacpb.Principal_Authenticated_{
+									Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{
+										MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://foo.abc"}}}}},
+								{Identifier: &v3rbacpb.Principal_Authenticated_{
+									Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{
+										MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "spiffe://bar"}}}}},
+								{Identifier: &v3rbacpb.Principal_Authenticated_{
+									Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{
+										MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}}}}},
+								{Identifier: &v3rbacpb.Principal_Authenticated_{
+									Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{
+										MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "spiffe://abc.*.com"}}}}},
+							}}}}},
+					Permissions: []*v3rbacpb.Permission{
+						{Rule: &v3rbacpb.Permission_Any{Any: true}}},
+				},
+			}},
+			wantAllowPolicy: &v3rbacpb.RBAC{Action: v3rbacpb.RBAC_ALLOW, Policies: map[string]*v3rbacpb.Policy{
+				"authz_allow_policy_1": {
+					Principals: []*v3rbacpb.Principal{
+						{Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{
+							Ids: []*v3rbacpb.Principal{
+								{Identifier: &v3rbacpb.Principal_Authenticated_{
+									Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{
+										MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: ""}}}}},
+							}}}}},
+					Permissions: []*v3rbacpb.Permission{
+						{Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{
+							Rules: []*v3rbacpb.Permission{
+								{Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{
+									Rules: []*v3rbacpb.Permission{
+										{Rule: &v3rbacpb.Permission_UrlPath{
+											UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{
+												MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "path-foo"}}}}}},
+									}}}}}}}}},
+				},
+				"authz_allow_policy_2": {
+					Principals: []*v3rbacpb.Principal{
+						{Identifier: &v3rbacpb.Principal_Any{Any: true}}},
+					Permissions: []*v3rbacpb.Permission{
+						{Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{
+							Rules: []*v3rbacpb.Permission{
+								{Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{
+									Rules: []*v3rbacpb.Permission{
+										{Rule: &v3rbacpb.Permission_UrlPath{
+											UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{
+												MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "path-bar"}}}}}},
+										{Rule: &v3rbacpb.Permission_UrlPath{
+											UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{
+												MatchPattern: &v3matcherpb.StringMatcher_Suffix{Suffix: "baz"}}}}}},
+									}}}},
+								{Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{
+									Rules: []*v3rbacpb.Permission{
+										{Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{
+											Rules: []*v3rbacpb.Permission{
+												{Rule: &v3rbacpb.Permission_Header{
+													Header: &v3routepb.HeaderMatcher{
+														Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "foo"}}}},
+												{Rule: &v3rbacpb.Permission_Header{
+													Header: &v3routepb.HeaderMatcher{
+														Name: "key-1", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: "bar"}}}},
+											}}}},
+										{Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{
+											Rules: []*v3rbacpb.Permission{
+												{Rule: &v3rbacpb.Permission_Header{
+													Header: &v3routepb.HeaderMatcher{
+														Name: "key-2", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "baz"}}}},
+											}}}}}}}}}}}}},
+				},
+			}},
+		},
+		"missing name field": {
+			authzPolicy: `{}`,
+			wantErr:     `"name" is not present`,
+		},
+		"invalid field type": {
+			authzPolicy: `{"name": 123}`,
+			wantErr:     "failed to unmarshal policy",
+		},
+		"missing allow rules field": {
+			authzPolicy:     `{"name": "authz-foo"}`,
+			wantErr:         `"allow_rules" is not present`,
+			wantDenyPolicy:  nil,
+			wantAllowPolicy: nil,
+		},
+		"missing rule name field": {
+			authzPolicy: `{
+				"name": "authz-foo",
+				"allow_rules": [{}]
+			}`,
+			wantErr: `"allow_rules" 0: "name" is not present`,
+		},
+		"missing header key": {
+			authzPolicy: `{
+				"name": "authz",
+				"allow_rules": [{
+					"name": "allow_policy_1",
+					"request": {"headers":[{"key":"key-a", "values": ["value-a"]}, {}]}
+				}]
+			}`,
+			wantErr: `"allow_rules" 0: "headers" 1: "key" is not present`,
+		},
+		"missing header values": {
+			authzPolicy: `{
+				"name": "authz",
+				"allow_rules": [{
+					"name": "allow_policy_1",
+					"request": {"headers":[{"key":"key-a"}]}
+				}]
+			}`,
+			wantErr: `"allow_rules" 0: "headers" 0: "values" is not present`,
+		},
+		"unsupported header": {
+			authzPolicy: `{
+				"name": "authz",
+				"allow_rules": [{
+					"name": "allow_policy_1",
+					"request": {"headers":[{"key":":method", "values":["GET"]}]}
+				}]
+			}`,
+			wantErr: `"allow_rules" 0: "headers" 0: unsupported "key" :method`,
+		},
+	}
+
+	for name, test := range tests {
+		t.Run(name, func(t *testing.T) {
+			gotDenyPolicy, gotAllowPolicy, gotErr := translatePolicy(test.authzPolicy)
+			if gotErr != nil && !strings.HasPrefix(gotErr.Error(), test.wantErr) {
+				t.Fatalf("unexpected error\nwant:%v\ngot:%v", test.wantErr, gotErr)
+			}
+			if diff := cmp.Diff(gotDenyPolicy, test.wantDenyPolicy, protocmp.Transform()); diff != "" {
+				t.Fatalf("unexpected deny policy\ndiff (-want +got):\n%s", diff)
+			}
+			if diff := cmp.Diff(gotAllowPolicy, test.wantAllowPolicy, protocmp.Transform()); diff != "" {
+				t.Fatalf("unexpected allow policy\ndiff (-want +got):\n%s", diff)
+			}
+		})
+	}
+}
diff --git a/balancer/grpclb/grpclb.go b/balancer/grpclb/grpclb.go
index a43d8964119f..49d11d0d2e21 100644
--- a/balancer/grpclb/grpclb.go
+++ b/balancer/grpclb/grpclb.go
@@ -25,6 +25,7 @@ package grpclb
 import (
 	"context"
 	"errors"
+	"fmt"
 	"sync"
 	"time"
 
@@ -221,6 +222,7 @@ type lbBalancer struct {
 	// when resolved address updates are received, and read in the goroutine
 	// handling fallback.
 	resolvedBackendAddrs []resolver.Address
+	connErr              error // the last connection error
 }
 
 // regeneratePicker takes a snapshot of the balancer, and generates a picker from
@@ -230,7 +232,7 @@ type lbBalancer struct {
 // Caller must hold lb.mu.
 func (lb *lbBalancer) regeneratePicker(resetDrop bool) {
 	if lb.state == connectivity.TransientFailure {
-		lb.picker = &errPicker{err: balancer.ErrTransientFailure}
+		lb.picker = &errPicker{err: fmt.Errorf("all SubConns are in TransientFailure, last connection error: %v", lb.connErr)}
 		return
 	}
 
@@ -336,6 +338,8 @@ func (lb *lbBalancer) UpdateSubConnState(sc balancer.SubConn, scs balancer.SubCo
 		// When an address was removed by resolver, b called RemoveSubConn but
 		// kept the sc's state in scStates. Remove state for this sc here.
 		delete(lb.scStates, sc)
+	case connectivity.TransientFailure:
+		lb.connErr = scs.ConnectionError
 	}
 	// Force regenerate picker if
 	// - this sc became ready from not-ready
diff --git a/balancer/grpclb/grpclb_test.go b/balancer/grpclb/grpclb_test.go
index 9cbb338c2415..d6275b657f90 100644
--- a/balancer/grpclb/grpclb_test.go
+++ b/balancer/grpclb/grpclb_test.go
@@ -1232,6 +1232,65 @@ func (s) TestGRPCLBPickFirst(t *testing.T) {
 	}
 }
 
+func (s) TestGRPCLBBackendConnectionErrorPropagation(t *testing.T) {
+	r := manual.NewBuilderWithScheme("whatever")
+
+	// Start up an LB which will tell the client to fall back
+	// right away.
+	tss, cleanup, err := newLoadBalancer(0, "", nil)
+	if err != nil {
+		t.Fatalf("failed to create new load balancer: %v", err)
+	}
+	defer cleanup()
+
+	// Start a standalone backend, to be used during fallback. The creds
+	// are intentionally misconfigured in order to simulate failure of a
+	// security handshake.
+	beLis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to listen %v", err)
+	}
+	defer beLis.Close()
+	standaloneBEs := startBackends("arbitrary.invalid.name", true, beLis)
+	defer stopBackends(standaloneBEs)
+
+	creds := serverNameCheckCreds{}
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	cc, err := grpc.DialContext(ctx, r.Scheme()+":///"+beServerName, grpc.WithResolvers(r),
+		grpc.WithTransportCredentials(&creds), grpc.WithContextDialer(fakeNameDialer))
+	if err != nil {
+		t.Fatalf("Failed to dial to the backend %v", err)
+	}
+	defer cc.Close()
+	testC := testpb.NewTestServiceClient(cc)
+
+	r.UpdateState(resolver.State{Addresses: []resolver.Address{{
+		Addr:       tss.lbAddr,
+		Type:       resolver.GRPCLB,
+		ServerName: lbServerName,
+	}, {
+		Addr: beLis.Addr().String(),
+		Type: resolver.Backend,
+	}}})
+
+	// If https://github.com/grpc/grpc-go/blob/65cabd74d8e18d7347fecd414fa8d83a00035f5f/balancer/grpclb/grpclb_test.go#L103
+	// changes, then expectedErrMsg may need to be updated.
+	const expectedErrMsg = "received unexpected server name"
+	ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		tss.ls.fallbackNow()
+		wg.Done()
+	}()
+	if _, err := testC.EmptyCall(ctx, &testpb.Empty{}); err == nil || !strings.Contains(err.Error(), expectedErrMsg) {
+		t.Fatalf("%v.EmptyCall(_, _) = _, %v, want _, rpc error containing substring: %q", testC, err, expectedErrMsg)
+	}
+	wg.Wait()
+}
+
 type failPreRPCCred struct{}
 
 func (failPreRPCCred) GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) {
diff --git a/balancer_conn_wrappers.go b/balancer_conn_wrappers.go
index f1bb6dd30737..dd8397963974 100644
--- a/balancer_conn_wrappers.go
+++ b/balancer_conn_wrappers.go
@@ -43,7 +43,7 @@ type ccBalancerWrapper struct {
 	cc         *ClientConn
 	balancerMu sync.Mutex // synchronizes calls to the balancer
 	balancer   balancer.Balancer
-	scBuffer   *buffer.Unbounded
+	updateCh   *buffer.Unbounded
 	closed     *grpcsync.Event
 	done       *grpcsync.Event
 
@@ -54,7 +54,7 @@ type ccBalancerWrapper struct {
 func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.BuildOptions) *ccBalancerWrapper {
 	ccb := &ccBalancerWrapper{
 		cc:       cc,
-		scBuffer: buffer.NewUnbounded(),
+		updateCh: buffer.NewUnbounded(),
 		closed:   grpcsync.NewEvent(),
 		done:     grpcsync.NewEvent(),
 		subConns: make(map[*acBalancerWrapper]struct{}),
@@ -69,15 +69,26 @@ func newCCBalancerWrapper(cc *ClientConn, b balancer.Builder, bopts balancer.Bui
 func (ccb *ccBalancerWrapper) watcher() {
 	for {
 		select {
-		case t := <-ccb.scBuffer.Get():
-			ccb.scBuffer.Load()
+		case t := <-ccb.updateCh.Get():
+			ccb.updateCh.Load()
 			if ccb.closed.HasFired() {
 				break
 			}
-			ccb.balancerMu.Lock()
-			su := t.(*scStateUpdate)
-			ccb.balancer.UpdateSubConnState(su.sc, balancer.SubConnState{ConnectivityState: su.state, ConnectionError: su.err})
-			ccb.balancerMu.Unlock()
+			switch u := t.(type) {
+			case *scStateUpdate:
+				ccb.balancerMu.Lock()
+				ccb.balancer.UpdateSubConnState(u.sc, balancer.SubConnState{ConnectivityState: u.state, ConnectionError: u.err})
+				ccb.balancerMu.Unlock()
+			case *acBalancerWrapper:
+				ccb.mu.Lock()
+				if ccb.subConns != nil {
+					delete(ccb.subConns, u)
+					ccb.cc.removeAddrConn(u.getAddrConn(), errConnDrain)
+				}
+				ccb.mu.Unlock()
+			default:
+				logger.Errorf("ccBalancerWrapper.watcher: unknown update %+v, type %T", t, t)
+			}
 		case <-ccb.closed.Done():
 		}
 
@@ -118,7 +129,7 @@ func (ccb *ccBalancerWrapper) handleSubConnStateChange(sc balancer.SubConn, s co
 	if sc == nil {
 		return
 	}
-	ccb.scBuffer.Put(&scStateUpdate{
+	ccb.updateCh.Put(&scStateUpdate{
 		sc:    sc,
 		state: s,
 		err:   err,
@@ -159,17 +170,10 @@ func (ccb *ccBalancerWrapper) NewSubConn(addrs []resolver.Address, opts balancer
 }
 
 func (ccb *ccBalancerWrapper) RemoveSubConn(sc balancer.SubConn) {
-	acbw, ok := sc.(*acBalancerWrapper)
-	if !ok {
-		return
-	}
-	ccb.mu.Lock()
-	defer ccb.mu.Unlock()
-	if ccb.subConns == nil {
-		return
-	}
-	delete(ccb.subConns, acbw)
-	ccb.cc.removeAddrConn(acbw.getAddrConn(), errConnDrain)
+	// RemoveSubConn() is handled in the watcher() goroutine, to avoid
+	// deadlock during switchBalancer() if the old balancer calls
+	// RemoveSubConn() in its Close().
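+	//
+	// The deletion from ccb.subConns and the call to removeAddrConn happen
+	// asynchronously, when the watcher goroutine dequeues this update (see
+	// the *acBalancerWrapper case above).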
+	ccb.updateCh.Put(sc)
 }
 
 func (ccb *ccBalancerWrapper) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) {
diff --git a/balancer_switching_test.go b/balancer_switching_test.go
index 2c6ed576620f..e9fee87d8f81 100644
--- a/balancer_switching_test.go
+++ b/balancer_switching_test.go
@@ -28,6 +28,7 @@ import (
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/balancer/roundrobin"
 	"google.golang.org/grpc/internal"
+	"google.golang.org/grpc/internal/balancer/stub"
 	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/resolver/manual"
 	"google.golang.org/grpc/serviceconfig"
@@ -531,6 +532,51 @@ func (s) TestSwitchBalancerGRPCLBWithGRPCLBNotRegistered(t *testing.T) {
 	}
 }
 
+const inlineRemoveSubConnBalancerName = "test-inline-remove-subconn-balancer"
+
+func init() {
+	stub.Register(inlineRemoveSubConnBalancerName, stub.BalancerFuncs{
+		Close: func(data *stub.BalancerData) {
+			data.ClientConn.RemoveSubConn(&acBalancerWrapper{})
+		},
+	})
+}
+
+// Test that when switching balancers, the old balancer calls RemoveSubConn
+// in Close.
+//
+// This test is to make sure this close doesn't cause a deadlock.
+func (s) TestSwitchBalancerOldRemoveSubConn(t *testing.T) {
+	r := manual.NewBuilderWithScheme("whatever")
+	cc, err := Dial(r.Scheme()+":///test.server", WithInsecure(), WithResolvers(r))
+	if err != nil {
+		t.Fatalf("failed to dial: %v", err)
+	}
+	defer cc.Close()
+	cc.updateResolverState(resolver.State{ServiceConfig: parseCfg(r, fmt.Sprintf(`{"loadBalancingPolicy": "%v"}`, inlineRemoveSubConnBalancerName))}, nil)
+	// This service config update will switch balancer from
+	// "test-inline-remove-subconn-balancer" to "pick_first". The test balancer
+	// will be closed, which will call cc.RemoveSubConn() inline (this
+	// RemoveSubConn is not required by the API, but some balancers might do
+	// it).
+	//
+	// This is to make sure the cc.RemoveSubConn() from Close() doesn't cause a
+	// deadlock (e.g. trying to grab a mutex while it's already locked).
+	//
+	// Do it in a goroutine so this test will fail with a helpful message
+	// (though the goroutine will still leak).
+	done := make(chan struct{})
+	go func() {
+		cc.updateResolverState(resolver.State{ServiceConfig: parseCfg(r, `{"loadBalancingPolicy": "pick_first"}`)}, nil)
+		close(done)
+	}()
+	select {
+	case <-time.After(defaultTestTimeout):
+		t.Fatalf("timeout waiting for updateResolverState to finish")
+	case <-done:
+	}
+}
+
 func parseCfg(r *manual.Resolver, s string) *serviceconfig.ParseResult {
 	scpr := r.CC.ParseServiceConfig(s)
 	if scpr.Err != nil {
diff --git a/call_test.go b/call_test.go
index abc4537ddb7d..8fdbc9b7eb7e 100644
--- a/call_test.go
+++ b/call_test.go
@@ -160,7 +160,7 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32) {
 	config := &transport.ServerConfig{
 		MaxStreams: maxStreams,
 	}
-	st, err := transport.NewServerTransport("http2", conn, config)
+	st, err := transport.NewServerTransport(conn, config)
 	if err != nil {
 		continue
 	}
diff --git a/clientconn.go b/clientconn.go
index 0236c81c4d4a..b2bccfed136e 100644
--- a/clientconn.go
+++ b/clientconn.go
@@ -711,7 +711,12 @@ func (cc *ClientConn) switchBalancer(name string) {
 		return
 	}
 	if cc.balancerWrapper != nil {
+		// Don't hold cc.mu while closing the balancers. The balancers may call
+		// methods that require cc.mu (e.g. cc.NewSubConn()). Holding the mutex
+		// would cause a deadlock in that case.
+		cc.mu.Unlock()
 		cc.balancerWrapper.close()
+		cc.mu.Lock()
 	}
 
 	builder := balancer.Get(name)
@@ -1424,26 +1429,14 @@ func (ac *addrConn) resetConnectBackoff() {
 	ac.mu.Unlock()
 }
 
-// getReadyTransport returns the transport if ac's state is READY.
-// Otherwise it returns nil, false.
-// If ac's state is IDLE, it will trigger ac to connect.
-func (ac *addrConn) getReadyTransport() (transport.ClientTransport, bool) {
+// getReadyTransport returns the transport if ac's state is READY or nil if not.
+func (ac *addrConn) getReadyTransport() transport.ClientTransport {
 	ac.mu.Lock()
-	if ac.state == connectivity.Ready && ac.transport != nil {
-		t := ac.transport
-		ac.mu.Unlock()
-		return t, true
-	}
-	var idle bool
-	if ac.state == connectivity.Idle {
-		idle = true
-	}
-	ac.mu.Unlock()
-	// Trigger idle ac to connect.
-	if idle {
-		ac.connect()
+	defer ac.mu.Unlock()
+	if ac.state == connectivity.Ready {
+		return ac.transport
 	}
-	return nil, false
+	return nil
 }
 
 // tearDown starts to tear down the addrConn.
diff --git a/clientconn_test.go b/clientconn_test.go
index 6c61666b7efb..a50db9419c2a 100644
--- a/clientconn_test.go
+++ b/clientconn_test.go
@@ -735,16 +735,15 @@ func (s) TestClientUpdatesParamsAfterGoAway(t *testing.T) {
 		time.Sleep(10 * time.Millisecond)
 		cc.mu.RLock()
 		v := cc.mkp.Time
+		cc.mu.RUnlock()
 		if v == 20*time.Second {
 			// Success
-			cc.mu.RUnlock()
 			return
 		}
 		if ctx.Err() != nil {
 			// Timeout
 			t.Fatalf("cc.dopts.copts.Keepalive.Time = %v , want 20s", v)
 		}
-		cc.mu.RUnlock()
 	}
 }
diff --git a/examples/go.sum b/examples/go.sum
index a5b967336903..4b9be0bebf85 100644
--- a/examples/go.sum
+++ b/examples/go.sum
@@ -1,8 +1,12 @@
 cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg=
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
+github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
+github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
 github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
 github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
+github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
 github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 h1:cqQfy1jclcSy/FwLjemeg3SR1yaINm74aQyupQ0Bl8M=
 github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
 github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed h1:OZmjad4L3H8ncOIR8rnb5MREYqG8ixi5+WbeUsquF0c=
@@ -38,6 +42,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
+github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
+github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
 github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
diff --git a/examples/helloworld/greeter_server/main.go b/examples/helloworld/greeter_server/main.go
index 15604f9fc1f4..4d077db92cfb 100644
--- a/examples/helloworld/greeter_server/main.go
+++ b/examples/helloworld/greeter_server/main.go
@@ -50,6 +50,7 @@ func main() {
 	}
 	s := grpc.NewServer()
 	pb.RegisterGreeterServer(s, &server{})
+	log.Printf("server listening at %v", lis.Addr())
 	if err := s.Serve(lis); err != nil {
 		log.Fatalf("failed to serve: %v", err)
 	}
diff --git a/go.mod b/go.mod
index 6eed9370b462..2f2cf1eb7669 100644
--- a/go.mod
+++ b/go.mod
@@ -3,6 +3,7 @@ module google.golang.org/grpc
 go 1.11
 
 require (
+	github.com/cespare/xxhash v1.1.0
 	github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403
 	github.com/envoyproxy/go-control-plane v0.9.9-0.20210512163311-63b5d3c536b0
 	github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b
diff --git a/go.sum b/go.sum
index 51fd1436e38f..372b4ea3d201 100644
--- a/go.sum
+++ b/go.sum
@@ -2,9 +2,13 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMT
 cloud.google.com/go v0.34.0 h1:eOI3/cP2VTU6uZLDYAoic+eyzzB9YyGmJ7eIjl8rOPg=
 cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
+github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE=
+github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
 github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
 github.com/census-instrumentation/opencensus-proto v0.2.1 h1:glEXhBS5PSLLv4IXzLA5yPRVX4bilULVyxxbrfOtDAk=
 github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
+github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko=
+github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc=
 github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
 github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
 github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403 h1:cqQfy1jclcSy/FwLjemeg3SR1yaINm74aQyupQ0Bl8M=
@@ -50,6 +54,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
 github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
+github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
+github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4=
 github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
diff --git a/internal/binarylog/sink.go b/internal/binarylog/sink.go
index 7d7a3056b71e..c2fdd58b3198 100644
--- a/internal/binarylog/sink.go
+++ b/internal/binarylog/sink.go
@@ -69,7 +69,8 @@ type writerSink struct {
 }
 
 func (ws *writerSink) Write(e *pb.GrpcLogEntry) error {
 	b, err := proto.Marshal(e)
 	if err != nil {
-		grpclogLogger.Infof("binary logging: failed to marshal proto message: %v", err)
+		grpclogLogger.Errorf("binary logging: failed to marshal proto message: %v", err)
+		return err
 	}
 	hdr := make([]byte, 4)
 	binary.BigEndian.PutUint32(hdr, uint32(len(b)))
@@ -85,24 +86,27 @@ func (ws *writerSink) Write(e *pb.GrpcLogEntry) error {
 func (ws *writerSink) Close() error { return nil }
 
 type bufferedSink struct {
-	mu     sync.Mutex
-	closer io.Closer
-	out    Sink          // out is built on buf.
-	buf    *bufio.Writer // buf is kept for flush.
-
-	writeStartOnce sync.Once
-	writeTicker    *time.Ticker
+	mu             sync.Mutex
+	closer         io.Closer
+	out            Sink          // out is built on buf.
+	buf            *bufio.Writer // buf is kept for flush.
+	flusherStarted bool
+
+	writeTicker *time.Ticker
+	done        chan struct{}
 }
 
 func (fs *bufferedSink) Write(e *pb.GrpcLogEntry) error {
-	// Start the write loop when Write is called.
-	fs.writeStartOnce.Do(fs.startFlushGoroutine)
 	fs.mu.Lock()
+	defer fs.mu.Unlock()
+	if !fs.flusherStarted {
+		// Start the write loop when Write is called.
+		fs.startFlushGoroutine()
+		fs.flusherStarted = true
+	}
 	if err := fs.out.Write(e); err != nil {
-		fs.mu.Unlock()
 		return err
 	}
-	fs.mu.Unlock()
 	return nil
 }
 
@@ -113,7 +117,12 @@ const (
 func (fs *bufferedSink) startFlushGoroutine() {
 	fs.writeTicker = time.NewTicker(bufFlushDuration)
 	go func() {
-		for range fs.writeTicker.C {
+		for {
+			select {
+			case <-fs.done:
+				return
+			case <-fs.writeTicker.C:
+			}
 			fs.mu.Lock()
 			if err := fs.buf.Flush(); err != nil {
 				grpclogLogger.Warningf("failed to flush to Sink: %v", err)
@@ -124,10 +133,12 @@ func (fs *bufferedSink) startFlushGoroutine() {
 }
 
 func (fs *bufferedSink) Close() error {
+	fs.mu.Lock()
+	defer fs.mu.Unlock()
 	if fs.writeTicker != nil {
 		fs.writeTicker.Stop()
 	}
-	fs.mu.Lock()
+	close(fs.done)
 	if err := fs.buf.Flush(); err != nil {
 		grpclogLogger.Warningf("failed to flush to Sink: %v", err)
 	}
@@ -137,7 +148,6 @@ func (fs *bufferedSink) Close() error {
 	if err := fs.out.Close(); err != nil {
 		grpclogLogger.Warningf("failed to close the Sink: %v", err)
 	}
-	fs.mu.Unlock()
 	return nil
 }
 
@@ -155,5 +165,6 @@ func NewBufferedSink(o io.WriteCloser) Sink {
 		closer: o,
 		out:    newWriterSink(bufW),
 		buf:    bufW,
+		done:   make(chan struct{}),
 	}
 }
diff --git a/internal/grpcrand/grpcrand.go b/internal/grpcrand/grpcrand.go
index 200b115ca209..740f83c2b766 100644
--- a/internal/grpcrand/grpcrand.go
+++ b/internal/grpcrand/grpcrand.go
@@ -31,26 +31,37 @@ var (
 	mu sync.Mutex
 )
 
+// Int implements rand.Int on the grpcrand global source.
+func Int() int {
+	mu.Lock()
+	defer mu.Unlock()
+	return r.Int()
+}
+
 // Int63n implements rand.Int63n on the grpcrand global source.
 func Int63n(n int64) int64 {
 	mu.Lock()
-	res := r.Int63n(n)
-	mu.Unlock()
-	return res
+	defer mu.Unlock()
+	return r.Int63n(n)
 }
 
 // Intn implements rand.Intn on the grpcrand global source.
 func Intn(n int) int {
 	mu.Lock()
-	res := r.Intn(n)
-	mu.Unlock()
-	return res
+	defer mu.Unlock()
+	return r.Intn(n)
}
 
 // Float64 implements rand.Float64 on the grpcrand global source.
 func Float64() float64 {
 	mu.Lock()
-	res := r.Float64()
-	mu.Unlock()
-	return res
+	defer mu.Unlock()
+	return r.Float64()
+}
+
+// Uint64 implements rand.Uint64 on the grpcrand global source.
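+// Like the other helpers in this package, it holds mu for the duration of
+// the call, so it is safe for concurrent use by multiple goroutines.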
+func Uint64() uint64 {
+	mu.Lock()
+	defer mu.Unlock()
+	return r.Uint64()
+}
diff --git a/internal/resolver/config_selector_test.go b/internal/resolver/config_selector_test.go
index e5a50995df11..e1dae8bde27c 100644
--- a/internal/resolver/config_selector_test.go
+++ b/internal/resolver/config_selector_test.go
@@ -48,6 +48,8 @@ func (s) TestSafeConfigSelector(t *testing.T) {
 	retChan1 := make(chan *RPCConfig)
 	retChan2 := make(chan *RPCConfig)
+	defer close(retChan1)
+	defer close(retChan2)
 
 	one := 1
 	two := 2
@@ -55,8 +57,8 @@ func (s) TestSafeConfigSelector(t *testing.T) {
 	resp1 := &RPCConfig{MethodConfig: serviceconfig.MethodConfig{MaxReqSize: &one}}
 	resp2 := &RPCConfig{MethodConfig: serviceconfig.MethodConfig{MaxReqSize: &two}}
 
-	cs1Called := make(chan struct{})
-	cs2Called := make(chan struct{})
+	cs1Called := make(chan struct{}, 1)
+	cs2Called := make(chan struct{}, 1)
 
 	cs1 := &fakeConfigSelector{
 		selectConfig: func(r RPCInfo) (*RPCConfig, error) {
diff --git a/internal/status/status.go b/internal/status/status.go
index 710223b8ded0..e5c6513edd13 100644
--- a/internal/status/status.go
+++ b/internal/status/status.go
@@ -97,7 +97,7 @@ func (s *Status) Err() error {
 	if s.Code() == codes.OK {
 		return nil
 	}
-	return &Error{e: s.Proto()}
+	return &Error{s: s}
 }
 
 // WithDetails returns a new status with the provided details messages appended to the status.
@@ -136,19 +136,23 @@ func (s *Status) Details() []interface{} {
 	return details
 }
 
+func (s *Status) String() string {
+	return fmt.Sprintf("rpc error: code = %s desc = %s", s.Code(), s.Message())
+}
+
 // Error wraps a pointer of a status proto. It implements error and Status,
 // and a nil *Error should never be returned by this package.
 type Error struct {
-	e *spb.Status
+	s *Status
 }
 
 func (e *Error) Error() string {
-	return fmt.Sprintf("rpc error: code = %s desc = %s", codes.Code(e.e.GetCode()), e.e.GetMessage())
+	return e.s.String()
 }
 
 // GRPCStatus returns the Status represented by se.
 func (e *Error) GRPCStatus() *Status {
-	return FromProto(e.e)
+	return e.s
 }
 
 // Is implements future error.Is functionality.
@@ -158,5 +162,5 @@ func (e *Error) Is(target error) bool {
 	if !ok {
 		return false
 	}
-	return proto.Equal(e.e, tse.e)
+	return proto.Equal(e.s.s, tse.s.s)
 }
diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go
index 119f01e3ebcf..fa41fec903ee 100644
--- a/internal/transport/http2_client.go
+++ b/internal/transport/http2_client.go
@@ -24,6 +24,7 @@ import (
 	"io"
 	"math"
 	"net"
+	"net/http"
 	"strconv"
 	"strings"
 	"sync"
@@ -241,7 +242,15 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts
 	// and passed to the credential handshaker. This makes it possible for
 	// address specific arbitrary data to reach the credential handshaker.
 	connectCtx = icredentials.NewClientHandshakeInfoContext(connectCtx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
-	conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, conn)
+	rawConn := conn
+	// Pull the deadline from the connectCtx, which will be used for
+	// timeouts in the authentication protocol handshake. Can ignore the
+	// boolean as the deadline will return the zero value, which will make
+	// the conn not time out on I/O operations.
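+	//
+	// For example, if the dial context carries a 20s timeout (an illustrative
+	// value), the security handshake below inherits a 20s I/O deadline on
+	// rawConn, and the deadline is cleared again right after the handshake.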
+	deadline, _ := connectCtx.Deadline()
+	rawConn.SetDeadline(deadline)
+	conn, authInfo, err = transportCreds.ClientHandshake(connectCtx, addr.ServerName, rawConn)
+	rawConn.SetDeadline(time.Time{})
 	if err != nil {
 		return nil, connectionErrorf(isTemporary(err), err, "transport: authentication handshake failed: %v", err)
 	}
@@ -607,26 +616,39 @@ func (t *http2Client) getCallAuthData(ctx context.Context, audience string, call
 	return callAuthData, nil
 }
 
-// PerformedIOError wraps an error to indicate IO may have been performed
-// before the error occurred.
-type PerformedIOError struct {
+// NewStreamError wraps an error and reports additional information.
+type NewStreamError struct {
 	Err error
+
+	DoNotRetry  bool
+	PerformedIO bool
 }
 
-// Error implements error.
-func (p PerformedIOError) Error() string {
-	return p.Err.Error()
+func (e NewStreamError) Error() string {
+	return e.Err.Error()
 }
 
 // NewStream creates a stream and registers it into the transport as "active"
-// streams.
+// streams. All non-nil errors returned will be *NewStreamError.
 func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Stream, err error) {
+	defer func() {
+		if err != nil {
+			nse, ok := err.(*NewStreamError)
+			if !ok {
+				nse = &NewStreamError{Err: err}
+			}
+			if len(t.perRPCCreds) > 0 || callHdr.Creds != nil {
+				// We may have performed I/O in the per-RPC creds callback, so do not
+				// allow transparent retry.
+				nse.PerformedIO = true
+			}
+			err = nse
+		}
+	}()
 	ctx = peer.NewContext(ctx, t.getPeer())
 	headerFields, err := t.createHeaderFields(ctx, callHdr)
 	if err != nil {
-		// We may have performed I/O in the per-RPC creds callback, so do not
-		// allow transparent retry.
-		return nil, PerformedIOError{err}
+		return nil, err
 	}
 	s := t.newStream(ctx, callHdr)
 	cleanup := func(err error) {
@@ -732,7 +754,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
 			break
 		}
 		if hdrListSizeErr != nil {
-			return nil, hdrListSizeErr
+			return nil, &NewStreamError{Err: hdrListSizeErr, DoNotRetry: true}
 		}
 		firstTry = false
 		select {
@@ -877,12 +899,18 @@ func (t *http2Client) Close(err error) {
 	// Append info about previous goaways if there were any, since this may be important
 	// for understanding the root cause for this connection to be closed.
 	_, goAwayDebugMessage := t.GetGoAwayReason()
+
+	var st *status.Status
 	if len(goAwayDebugMessage) > 0 {
-		err = fmt.Errorf("closing transport due to: %v, received prior goaway: %v", err, goAwayDebugMessage)
+		st = status.Newf(codes.Unavailable, "closing transport due to: %v, received prior goaway: %v", err, goAwayDebugMessage)
+		err = st.Err()
+	} else {
+		st = status.New(codes.Unavailable, err.Error())
 	}
+
 	// Notify all active streams.
 	for _, s := range streams {
-		t.closeStream(s, err, false, http2.ErrCodeNo, status.New(codes.Unavailable, err.Error()), nil, false)
+		t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false)
 	}
 	if t.statsHandler != nil {
 		connEnd := &stats.ConnEnd{
@@ -1220,7 +1248,11 @@ func (t *http2Client) setGoAwayReason(f *http2.GoAwayFrame) {
 			t.goAwayReason = GoAwayTooManyPings
 		}
 	}
-	t.goAwayDebugMessage = fmt.Sprintf("code: %s, debug data: %v", f.ErrCode, string(f.DebugData()))
+	if len(f.DebugData()) == 0 {
+		t.goAwayDebugMessage = fmt.Sprintf("code: %s", f.ErrCode)
+	} else {
+		t.goAwayDebugMessage = fmt.Sprintf("code: %s, debug data: %q", f.ErrCode, string(f.DebugData()))
+	}
 }
 
 func (t *http2Client) GetGoAwayReason() (GoAwayReason, string) {
@@ -1266,29 +1298,40 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
 		// that the peer is speaking gRPC and we are in gRPC mode.
 		isGRPC         = !initialHeader
 		mdata          = make(map[string][]string)
-		contentTypeErr string
+		contentTypeErr = "malformed header: missing HTTP content-type"
 		grpcMessage    string
 		statusGen      *status.Status
-
-		httpStatus string
-		rawStatus  string
+		httpStatusCode *int
+		httpStatusErr  string
+		rawStatusCode  = codes.Unknown
 		// headerError is set if an error is encountered while parsing the headers
 		headerError string
 	)
 
+	if initialHeader {
+		httpStatusErr = "malformed header: missing HTTP status"
+	}
+
 	for _, hf := range frame.Fields {
 		switch hf.Name {
 		case "content-type":
 			if _, validContentType := grpcutil.ContentSubtype(hf.Value); !validContentType {
-				contentTypeErr = fmt.Sprintf("transport: received the unexpected content-type %q", hf.Value)
+				contentTypeErr = fmt.Sprintf("transport: received unexpected content-type %q", hf.Value)
 				break
 			}
+			contentTypeErr = ""
 			mdata[hf.Name] = append(mdata[hf.Name], hf.Value)
 			isGRPC = true
 		case "grpc-encoding":
 			s.recvCompress = hf.Value
 		case "grpc-status":
-			rawStatus = hf.Value
+			code, err := strconv.ParseInt(hf.Value, 10, 32)
+			if err != nil {
+				se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err))
+				t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
+				return
+			}
+			rawStatusCode = codes.Code(uint32(code))
 		case "grpc-message":
 			grpcMessage = decodeGrpcMessage(hf.Value)
 		case "grpc-status-details-bin":
@@ -1298,7 +1341,27 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
 				headerError = fmt.Sprintf("transport: malformed grpc-status-details-bin: %v", err)
 			}
 		case ":status":
-			httpStatus = hf.Value
+			if hf.Value == "200" {
+				httpStatusErr = ""
+				statusCode := 200
+				httpStatusCode = &statusCode
+				break
+			}
+
+			c, err := strconv.ParseInt(hf.Value, 10, 32)
+			if err != nil {
+				se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err))
+				t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream)
+				return
+			}
+			statusCode := int(c)
+			httpStatusCode = &statusCode
+
+			httpStatusErr = fmt.Sprintf(
+				"unexpected HTTP status code received from server: %d (%s)",
+				statusCode,
+				http.StatusText(statusCode),
+			)
 		default:
 			if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) {
 				break
@@ -1313,30 +1376,25 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
 		}
 	}
 
-	if !isGRPC {
-		var (
-			code           = codes.Internal // when header does not include HTTP status, return INTERNAL
-			httpStatusCode int
-		)
-
-		if httpStatus != "" {
-			c, err := strconv.ParseInt(httpStatus, 10, 32)
-			if err != nil {
-				se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err))
http-status: %v", err)) - t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) - return - } - httpStatusCode = int(c) + if !isGRPC || httpStatusErr != "" { + var code = codes.Internal // when header does not include HTTP status, return INTERNAL + if httpStatusCode != nil { var ok bool - code, ok = HTTPStatusConvTab[httpStatusCode] + code, ok = HTTPStatusConvTab[*httpStatusCode] if !ok { code = codes.Unknown } } - + var errs []string + if httpStatusErr != "" { + errs = append(errs, httpStatusErr) + } + if contentTypeErr != "" { + errs = append(errs, contentTypeErr) + } // Verify the HTTP response is a 200. - se := status.New(code, constructHTTPErrMsg(&httpStatusCode, contentTypeErr)) + se := status.New(code, strings.Join(errs, "; ")) t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) return } @@ -1393,16 +1451,6 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } if statusGen == nil { - rawStatusCode := codes.Unknown - if rawStatus != "" { - code, err := strconv.ParseInt(rawStatus, 10, 32) - if err != nil { - se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err)) - t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) - return - } - rawStatusCode = codes.Code(uint32(code)) - } statusGen = status.New(rawStatusCode, grpcMessage) } diff --git a/internal/transport/http2_server.go b/internal/transport/http2_server.go index 21a3c8526158..e3799d50aa71 100644 --- a/internal/transport/http2_server.go +++ b/internal/transport/http2_server.go @@ -125,9 +125,14 @@ type http2Server struct { connectionID uint64 } -// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is -// returned if something goes wrong. -func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { +// NewServerTransport creates a http2 transport with conn and configuration +// options from config. +// +// It returns a non-nil transport and a nil error on success. On failure, it +// returns a non-nil transport and a nil-error. For a special case where the +// underlying conn gets closed before the client preface could be read, it +// returns a nil transport and a nil error. +func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, err error) { writeBufSize := config.WriteBufferSize readBufSize := config.ReadBufferSize maxHeaderListSize := defaultServerMaxHeaderListSize @@ -266,6 +271,13 @@ func newHTTP2Server(conn net.Conn, config *ServerConfig) (_ ServerTransport, err // Check the validity of client preface. preface := make([]byte, len(clientPreface)) if _, err := io.ReadFull(t.conn, preface); err != nil { + // In deployments where a gRPC server runs behind a cloud load balancer + // which performs regular TCP level health checks, the connection is + // closed immediately by the latter. Skipping the error here will help + // reduce log clutter. 
+		if err == io.EOF {
+			return nil, nil
+		}
 		return nil, connectionErrorf(false, err, "transport: http2Server.HandleStreams failed to receive the preface from client: %v", err)
 	}
 	if !bytes.Equal(preface, clientPreface) {
diff --git a/internal/transport/http_util.go b/internal/transport/http_util.go
index 15d775fca3cc..d8247bcdf692 100644
--- a/internal/transport/http_util.go
+++ b/internal/transport/http_util.go
@@ -173,26 +173,6 @@ func decodeGRPCStatusDetails(rawDetails string) (*status.Status, error) {
 	return status.FromProto(st), nil
 }
 
-// constructErrMsg constructs error message to be returned in HTTP fallback mode.
-// Format: HTTP status code and its corresponding message + content-type error message.
-func constructHTTPErrMsg(httpStatus *int, contentTypeErr string) string {
-	var errMsgs []string
-
-	if httpStatus == nil {
-		errMsgs = append(errMsgs, "malformed header: missing HTTP status")
-	} else {
-		errMsgs = append(errMsgs, fmt.Sprintf("%s: HTTP status code %d", http.StatusText(*(httpStatus)), *httpStatus))
-	}
-
-	if contentTypeErr == "" {
-		errMsgs = append(errMsgs, "transport: missing content-type field")
-	} else {
-		errMsgs = append(errMsgs, contentTypeErr)
-	}
-
-	return strings.Join(errMsgs, "; ")
-}
-
 type timeoutUnit uint8
 
 const (
diff --git a/internal/transport/networktype/networktype.go b/internal/transport/networktype/networktype.go
index 96967428b515..7bb53cff1011 100644
--- a/internal/transport/networktype/networktype.go
+++ b/internal/transport/networktype/networktype.go
@@ -17,7 +17,7 @@
  */
 
 // Package networktype declares the network type to be used in the default
-// dailer. Attribute of a resolver.Address.
+// dialer. Attribute of a resolver.Address.
 package networktype
 
 import (
diff --git a/internal/transport/transport.go b/internal/transport/transport.go
index 3ab945aa4d20..14198126457b 100644
--- a/internal/transport/transport.go
+++ b/internal/transport/transport.go
@@ -532,12 +532,6 @@ type ServerConfig struct {
 	HeaderTableSize *uint32
 }
 
-// NewServerTransport creates a ServerTransport with conn or non-nil error
-// if it fails.
-func NewServerTransport(protocol string, conn net.Conn, config *ServerConfig) (ServerTransport, error) {
-	return newHTTP2Server(conn, config)
-}
-
 // ConnectOptions covers all relevant options for communicating with the server.
 type ConnectOptions struct {
 	// UserAgent is the application user agent.
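The transport changes above fold newHTTP2Server into the exported NewServerTransport, drop the unused protocol string argument, and allow a nil transport with a nil error when the conn closes before the client preface arrives. A minimal sketch of an updated call site, under assumed names (lis, serveTransport) and an illustrative MaxStreams value:

	// Hypothetical accept loop showing the new two-argument signature.
	conn, err := lis.Accept()
	if err != nil {
		return err
	}
	st, err := transport.NewServerTransport(conn, &transport.ServerConfig{MaxStreams: 100})
	if err != nil {
		return err // transport could not be created
	}
	if st == nil {
		// conn was closed before the client preface arrived (e.g. by a TCP
		// health checker); there is nothing to serve.
		return nil
	}
	go serveTransport(st) // serveTransport is assumed for illustration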
diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go
index 3aef86277140..28cace0ba5d0 100644
--- a/internal/transport/transport_test.go
+++ b/internal/transport/transport_test.go
@@ -323,7 +323,7 @@ func (s *server) start(t *testing.T, port int, serverConfig *ServerConfig, ht hT
 	if err != nil {
 		return
 	}
-	transport, err := NewServerTransport("http2", conn, serverConfig)
+	transport, err := NewServerTransport(conn, serverConfig)
 	if err != nil {
 		return
 	}
@@ -780,7 +780,7 @@ func (s) TestGracefulClose(t *testing.T) {
 		go func() {
 			defer wg.Done()
 			str, err := ct.NewStream(ctx, &CallHdr{})
-			if err == ErrConnClosing {
+			if err != nil && err.(*NewStreamError).Err == ErrConnClosing {
 				return
 			} else if err != nil {
 				t.Errorf("_.NewStream(_, _) = _, %v, want _, %v", err, ErrConnClosing)
@@ -1978,6 +1978,31 @@ func (s) TestClientHandshakeInfo(t *testing.T) {
 }
 
 func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
+	testStream := func() *Stream {
+		return &Stream{
+			done:       make(chan struct{}),
+			headerChan: make(chan struct{}),
+			buf: &recvBuffer{
+				c:  make(chan recvMsg),
+				mu: sync.Mutex{},
+			},
+		}
+	}
+
+	testClient := func(ts *Stream) *http2Client {
+		return &http2Client{
+			mu: sync.Mutex{},
+			activeStreams: map[uint32]*Stream{
+				0: ts,
+			},
+			controlBuf: &controlBuffer{
+				ch:   make(chan struct{}),
+				done: make(chan struct{}),
+				list: &itemList{},
+			},
+		}
+	}
+
 	for _, test := range []struct {
 		name string
 		// input
 		metaHeaderFrame *http2.MetaHeadersFrame
 		wantStatus *status.Status
 	}{
 		{
 			name: "valid header",
 			metaHeaderFrame: &http2.MetaHeadersFrame{
 				Fields: []hpack.HeaderField{
 					{Name: "content-type", Value: "application/grpc"},
 					{Name: "grpc-status", Value: "0"},
+					{Name: ":status", Value: "200"},
 				},
 			},
 			// no error
 			wantStatus: status.New(codes.OK, ""),
 		},
+		{
+			name: "missing content-type header",
+			metaHeaderFrame: &http2.MetaHeadersFrame{
+				Fields: []hpack.HeaderField{
+					{Name: "grpc-status", Value: "0"},
+					{Name: ":status", Value: "200"},
+				},
+			},
+			wantStatus: status.New(
+				codes.Unknown,
+				"malformed header: missing HTTP content-type",
+			),
+		},
 		{
 			name: "invalid grpc status header field",
 			metaHeaderFrame: &http2.MetaHeadersFrame{
 				Fields: []hpack.HeaderField{
 					{Name: "content-type", Value: "application/grpc"},
 					{Name: "grpc-status", Value: "xxxx"},
+					{Name: ":status", Value: "200"},
 				},
 			},
 			wantStatus: status.New(
@@ -2018,7 +2058,7 @@ func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
 			},
 			wantStatus: status.New(
 				codes.Internal,
-				": HTTP status code 0; transport: received the unexpected content-type \"application/json\"",
+				"malformed header: missing HTTP status; transport: received unexpected content-type \"application/json\"",
 			),
 		},
 		{
@@ -2045,27 +2085,56 @@ func (s) TestClientDecodeHeaderStatusErr(t *testing.T) {
 				"peer header list size exceeded limit",
 			),
 		},
+		{
+			name: "bad status in grpc mode",
+			metaHeaderFrame: &http2.MetaHeadersFrame{
+				Fields: []hpack.HeaderField{
+					{Name: "content-type", Value: "application/grpc"},
+					{Name: "grpc-status", Value: "0"},
+					{Name: ":status", Value: "504"},
+				},
+			},
+			wantStatus: status.New(
+				codes.Unavailable,
+				"unexpected HTTP status code received from server: 504 (Gateway Timeout)",
+			),
+		},
+		{
+			name: "missing http status",
+			metaHeaderFrame: &http2.MetaHeadersFrame{
+				Fields: []hpack.HeaderField{
+					{Name: "content-type", Value: "application/grpc"},
+				},
+			},
+			wantStatus: status.New(
+				codes.Internal,
+				"malformed header: missing HTTP status",
+			),
+		},
 	} {
+
 		t.Run(test.name, func(t *testing.T) {
-			ts := &Stream{
-				done: make(chan struct{}),
-				headerChan: make(chan struct{}),
struct{}), - buf: &recvBuffer{ - c: make(chan recvMsg), - mu: sync.Mutex{}, + ts := testStream() + s := testClient(ts) + + test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ + FrameHeader: http2.FrameHeader{ + StreamID: 0, }, } - s := &http2Client{ - mu: sync.Mutex{}, - activeStreams: map[uint32]*Stream{ - 0: ts, - }, - controlBuf: &controlBuffer{ - ch: make(chan struct{}), - done: make(chan struct{}), - list: &itemList{}, - }, + + s.operateHeaders(test.metaHeaderFrame) + + got := ts.status + want := test.wantStatus + if got.Code() != want.Code() || got.Message() != want.Message() { + t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) } + }) + t.Run(fmt.Sprintf("%s-end_stream", test.name), func(t *testing.T) { + ts := testStream() + s := testClient(ts) + test.metaHeaderFrame.HeadersFrame = &http2.HeadersFrame{ FrameHeader: http2.FrameHeader{ StreamID: 0, @@ -2078,7 +2147,7 @@ func (s) TestClientDecodeHeaderStatusErr(t *testing.T) { got := ts.status want := test.wantStatus if got.Code() != want.Code() || got.Message() != want.Message() { - t.Fatalf("operateHeaders(%v); status = %v; want %v", test.metaHeaderFrame, got, want) + t.Fatalf("operateHeaders(%v); status = \ngot: %s\nwant: %s", test.metaHeaderFrame, got, want) } }) } diff --git a/internal/xds/bootstrap.go b/internal/xds/bootstrap.go index a3f80d8f2496..1d74ab46a114 100644 --- a/internal/xds/bootstrap.go +++ b/internal/xds/bootstrap.go @@ -65,11 +65,32 @@ type BootstrapOptions struct { // completed successfully. It is the responsibility of the caller to invoke the // cleanup function at the end of the test. func SetupBootstrapFile(opts BootstrapOptions) (func(), error) { + bootstrapContents, err := BootstrapContents(opts) + if err != nil { + return nil, err + } f, err := ioutil.TempFile("", "test_xds_bootstrap_*") if err != nil { return nil, fmt.Errorf("failed to created bootstrap file: %v", err) } + if err := ioutil.WriteFile(f.Name(), bootstrapContents, 0644); err != nil { + return nil, fmt.Errorf("failed to write bootstrap file: %v", err) + } + logger.Infof("Created bootstrap file at %q with contents: %s\n", f.Name(), bootstrapContents) + + origBootstrapFileName := env.BootstrapFileName + env.BootstrapFileName = f.Name() + return func() { + os.Remove(f.Name()) + env.BootstrapFileName = origBootstrapFileName + }, nil +} + +// BootstrapContents returns the contents to go into a bootstrap file, +// environment, or configuration passed to +// xds.NewXDSResolverWithConfigForTesting.
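+//
+// A sketch of the intended use from a test (editor's illustration, not part
+// of the patch; the opts construction and most error handling are elided
+// because the BootstrapOptions fields are not shown in this hunk):
+//
+//	contents, err := BootstrapContents(opts)
+//	if err != nil {
+//		t.Fatal(err)
+//	}
+//	resolver, err := xds.NewXDSResolverWithConfigForTesting(contents)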
+func BootstrapContents(opts BootstrapOptions) ([]byte, error) { cfg := &bootstrapConfig{ XdsServers: []server{ { @@ -100,17 +121,7 @@ func SetupBootstrapFile(opts BootstrapOptions) (func(), error) { if err != nil { return nil, fmt.Errorf("failed to created bootstrap file: %v", err) } - if err := ioutil.WriteFile(f.Name(), bootstrapContents, 0644); err != nil { - return nil, fmt.Errorf("failed to created bootstrap file: %v", err) - } - logger.Infof("Created bootstrap file at %q with contents: %s\n", f.Name(), bootstrapContents) - - origBootstrapFileName := env.BootstrapFileName - env.BootstrapFileName = f.Name() - return func() { - os.Remove(f.Name()) - env.BootstrapFileName = origBootstrapFileName - }, nil + return bootstrapContents, nil } type bootstrapConfig struct { diff --git a/internal/xds/env/env.go b/internal/xds/env/env.go index db9ac93b968c..448fd63c21eb 100644 --- a/internal/xds/env/env.go +++ b/internal/xds/env/env.go @@ -39,9 +39,7 @@ const ( // When both bootstrap FileName and FileContent are set, FileName is used. BootstrapFileContentEnv = "GRPC_XDS_BOOTSTRAP_CONFIG" - circuitBreakingSupportEnv = "GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING" - timeoutSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT" - faultInjectionSupportEnv = "GRPC_XDS_EXPERIMENTAL_FAULT_INJECTION" + ringHashSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH" clientSideSecuritySupportEnv = "GRPC_XDS_EXPERIMENTAL_SECURITY_SUPPORT" aggregateAndDNSSupportEnv = "GRPC_XDS_EXPERIMENTAL_ENABLE_AGGREGATE_AND_LOGICAL_DNS_CLUSTER" @@ -52,28 +50,20 @@ const ( var ( // BootstrapFileName holds the name of the file which contains xDS bootstrap // configuration. Users can specify the location of the bootstrap file by - // setting the environment variable "GRPC_XDS_BOOSTRAP". + // setting the environment variable "GRPC_XDS_BOOTSTRAP". // // When both bootstrap FileName and FileContent are set, FileName is used. BootstrapFileName = os.Getenv(BootstrapFileNameEnv) // BootstrapFileContent holds the content of the xDS bootstrap // configuration. Users can specify the bootstrap config by - // setting the environment variable "GRPC_XDS_BOOSTRAP_CONFIG". + // setting the environment variable "GRPC_XDS_BOOTSTRAP_CONFIG". // // When both bootstrap FileName and FileContent are set, FileName is used. BootstrapFileContent = os.Getenv(BootstrapFileContentEnv) - - // CircuitBreakingSupport indicates whether circuit breaking support is - // enabled, which can be disabled by setting the environment variable - // "GRPC_XDS_EXPERIMENTAL_CIRCUIT_BREAKING" to "false". - CircuitBreakingSupport = !strings.EqualFold(os.Getenv(circuitBreakingSupportEnv), "false") - // TimeoutSupport indicates whether support for max_stream_duration in - // route actions is enabled. This can be disabled by setting the - // environment variable "GRPC_XDS_EXPERIMENTAL_ENABLE_TIMEOUT" to "false". - TimeoutSupport = !strings.EqualFold(os.Getenv(timeoutSupportEnv), "false") - // FaultInjectionSupport is used to control both fault injection and HTTP - // filter support. - FaultInjectionSupport = !strings.EqualFold(os.Getenv(faultInjectionSupportEnv), "false") + // RingHashSupport indicates whether ring hash support is enabled, which can + // be enabled by setting the environment variable + // "GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH" to "true". + RingHashSupport = strings.EqualFold(os.Getenv(ringHashSupportEnv), "true") // ClientSideSecuritySupport is used to control processing of security // configuration on the client-side. 
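Editor's note on the env.go hunk above: the deleted variables were on by default and could only be switched off, while the new ring hash variable is off by default and must be switched on. A minimal standalone sketch of the two polarities (editor's illustration, not part of the patch):

package main

import (
	"fmt"
	"os"
	"strings"
)

// optIn mirrors the new RingHashSupport check: true only when the variable
// is explicitly set to "true" (case-insensitively).
func optIn(name string) bool {
	return strings.EqualFold(os.Getenv(name), "true")
}

// optOut mirrors the removed circuit-breaking/timeout/fault-injection
// checks: true unless the variable is explicitly set to "false".
func optOut(name string) bool {
	return !strings.EqualFold(os.Getenv(name), "false")
}

func main() {
	os.Setenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH", "true")
	fmt.Println(optIn("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")) // true
	fmt.Println(optIn("UNSET_FLAG"))  // false: opt-in features default off
	fmt.Println(optOut("UNSET_FLAG")) // true: opt-out features default on
}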
// diff --git a/internal/xds/matcher/matcher_header.go b/internal/xds/matcher/matcher_header.go index f9c0322179e8..35a22adadcf2 100644 --- a/internal/xds/matcher/matcher_header.go +++ b/internal/xds/matcher/matcher_header.go @@ -27,10 +27,10 @@ import ( "google.golang.org/grpc/metadata" ) -// HeaderMatcherInterface is an interface for header matchers. These are +// HeaderMatcher is an interface for header matchers. These are // documented in (EnvoyProxy link here?). These matchers will match on different // aspects of HTTP header name/value pairs. -type HeaderMatcherInterface interface { +type HeaderMatcher interface { Match(metadata.MD) bool String() string } @@ -203,13 +203,42 @@ func (hsm *HeaderSuffixMatcher) String() string { return fmt.Sprintf("headerSuffix:%v:%v", hsm.key, hsm.suffix) } +// HeaderContainsMatcher matches on whether the header value contains the +// value passed into this struct. +type HeaderContainsMatcher struct { + key string + contains string +} + +// NewHeaderContainsMatcher returns a new HeaderContainsMatcher. key is the HTTP +// Header key to match on, and contains is the value that the header should +// contain for a successful match. An empty contains string does not +// work; use HeaderPresentMatcher in that case. +func NewHeaderContainsMatcher(key string, contains string) *HeaderContainsMatcher { + return &HeaderContainsMatcher{key: key, contains: contains} +} + +// Match returns whether the passed in HTTP Headers match according to the +// HeaderContainsMatcher. +func (hcm *HeaderContainsMatcher) Match(md metadata.MD) bool { + v, ok := mdValuesFromOutgoingCtx(md, hcm.key) + if !ok { + return false + } + return strings.Contains(v, hcm.contains) +} + +func (hcm *HeaderContainsMatcher) String() string { + return fmt.Sprintf("headerContains:%v:%v", hcm.key, hcm.contains) +} + // InvertMatcher inverts the match result of the underlying header matcher. type InvertMatcher struct { - m HeaderMatcherInterface + m HeaderMatcher } // NewInvertMatcher returns a new InvertMatcher. -func NewInvertMatcher(m HeaderMatcherInterface) *InvertMatcher { +func NewInvertMatcher(m HeaderMatcher) *InvertMatcher { return &InvertMatcher{m: m} } diff --git a/internal/xds/matcher/matcher_header_test.go b/internal/xds/matcher/matcher_header_test.go index 911e7bcfaca1..902b8112e889 100644 --- a/internal/xds/matcher/matcher_header_test.go +++ b/internal/xds/matcher/matcher_header_test.go @@ -309,7 +309,7 @@ func TestHeaderSuffixMatcherMatch(t *testing.T) { func TestInvertMatcherMatch(t *testing.T) { tests := []struct { name string - m HeaderMatcherInterface + m HeaderMatcher md metadata.MD }{ { diff --git a/internal/xds/rbac/matchers.go b/internal/xds/rbac/matchers.go new file mode 100644 index 000000000000..25d1cc0e8b98 --- /dev/null +++ b/internal/xds/rbac/matchers.go @@ -0,0 +1,416 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package rbac + +import ( + "errors" + "fmt" + "net" + "regexp" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3route_componentspb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + internalmatcher "google.golang.org/grpc/internal/xds/matcher" +) + +// matcher is an interface that takes data about an incoming RPC and returns +// whether the RPC matches the criteria of the implementing matcher. +type matcher interface { + match(data *rpcData) bool +} + +// policyMatcher helps determine whether an incoming RPC call matches a policy. +// A policy is a logical role (e.g. Service Admin), which is made up of +// permissions and principals. A principal is an identity (or set of +// identities) of a downstream subject that is assigned the policy (role), and +// a permission is an action that such a principal can take. A policy matches +// if both a permission and a principal match, as determined by the child +// permissions and principals matchers (each is an or over its entries). +// policyMatcher implements the matcher interface. +type policyMatcher struct { + permissions *orMatcher + principals *orMatcher +} + +func newPolicyMatcher(policy *v3rbacpb.Policy) (*policyMatcher, error) { + permissions, err := matchersFromPermissions(policy.Permissions) + if err != nil { + return nil, err + } + principals, err := matchersFromPrincipals(policy.Principals) + if err != nil { + return nil, err + } + return &policyMatcher{ + permissions: &orMatcher{matchers: permissions}, + principals: &orMatcher{matchers: principals}, + }, nil +} + +func (pm *policyMatcher) match(data *rpcData) bool { + // A policy matches if and only if at least one of its permissions match the + // action taking place AND at least one of its principals match the + // downstream peer. + return pm.permissions.match(data) && pm.principals.match(data) +} + +// matchersFromPermissions takes a list of permissions (can also be +// a single permission, e.g. from a not matcher which is logically !permission) +// and returns a list of matchers which correspond to those permissions. This will +// be called in many instances throughout the initial construction of the RBAC +// engine from the AND and OR matchers and also from the NOT matcher.
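+//
+// For example (editor's note), a permission of the form
+//
+//	or_rules:{rules:{url_path:{...}} rules:{destination_port:443}}
+//
+// becomes an orMatcher over a urlPathMatcher and a portMatcher, and a
+// not_rule wraps the matcher built for its single inner rule in a notMatcher.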
+func matchersFromPermissions(permissions []*v3rbacpb.Permission) ([]matcher, error) { + var matchers []matcher + for _, permission := range permissions { + switch permission.GetRule().(type) { + case *v3rbacpb.Permission_AndRules: + mList, err := matchersFromPermissions(permission.GetAndRules().Rules) + if err != nil { + return nil, err + } + matchers = append(matchers, &andMatcher{matchers: mList}) + case *v3rbacpb.Permission_OrRules: + mList, err := matchersFromPermissions(permission.GetOrRules().Rules) + if err != nil { + return nil, err + } + matchers = append(matchers, &orMatcher{matchers: mList}) + case *v3rbacpb.Permission_Any: + matchers = append(matchers, &alwaysMatcher{}) + case *v3rbacpb.Permission_Header: + m, err := newHeaderMatcher(permission.GetHeader()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_UrlPath: + m, err := newURLPathMatcher(permission.GetUrlPath()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_DestinationIp: + m, err := newDestinationIPMatcher(permission.GetDestinationIp()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Permission_DestinationPort: + matchers = append(matchers, newPortMatcher(permission.GetDestinationPort())) + case *v3rbacpb.Permission_NotRule: + mList, err := matchersFromPermissions([]*v3rbacpb.Permission{{Rule: permission.GetNotRule().Rule}}) + if err != nil { + return nil, err + } + matchers = append(matchers, ¬Matcher{matcherToNot: mList[0]}) + case *v3rbacpb.Permission_Metadata: + // Not supported in gRPC RBAC currently - a permission typed as + // Metadata in the initial config will be a no-op. + case *v3rbacpb.Permission_RequestedServerName: + // Not supported in gRPC RBAC currently - a permission typed as + // requested server name in the initial config will be a no-op. + } + } + return matchers, nil +} + +func matchersFromPrincipals(principals []*v3rbacpb.Principal) ([]matcher, error) { + var matchers []matcher + for _, principal := range principals { + switch principal.GetIdentifier().(type) { + case *v3rbacpb.Principal_AndIds: + mList, err := matchersFromPrincipals(principal.GetAndIds().Ids) + if err != nil { + return nil, err + } + matchers = append(matchers, &andMatcher{matchers: mList}) + case *v3rbacpb.Principal_OrIds: + mList, err := matchersFromPrincipals(principal.GetOrIds().Ids) + if err != nil { + return nil, err + } + matchers = append(matchers, &orMatcher{matchers: mList}) + case *v3rbacpb.Principal_Any: + matchers = append(matchers, &alwaysMatcher{}) + case *v3rbacpb.Principal_Authenticated_: + authenticatedMatcher, err := newAuthenticatedMatcher(principal.GetAuthenticated()) + if err != nil { + return nil, err + } + matchers = append(matchers, authenticatedMatcher) + case *v3rbacpb.Principal_DirectRemoteIp: + m, err := newSourceIPMatcher(principal.GetDirectRemoteIp()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_Header: + // Do we need an error here? 
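+ // (editor's note: probably not; newHeaderMatcher already returns an
+ // error for unknown matcher types, so invalid configs surface there.)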
+ m, err := newHeaderMatcher(principal.GetHeader()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_UrlPath: + m, err := newURLPathMatcher(principal.GetUrlPath()) + if err != nil { + return nil, err + } + matchers = append(matchers, m) + case *v3rbacpb.Principal_NotId: + mList, err := matchersFromPrincipals([]*v3rbacpb.Principal{{Identifier: principal.GetNotId().Identifier}}) + if err != nil { + return nil, err + } + matchers = append(matchers, &notMatcher{matcherToNot: mList[0]}) + case *v3rbacpb.Principal_SourceIp: + // The source ip principal identifier is deprecated. Thus, a + // principal typed as a source ip in the identifier will be a no-op. + // The config should use DirectRemoteIp instead. + case *v3rbacpb.Principal_RemoteIp: + // Not supported in gRPC RBAC currently - a principal typed as + // Remote Ip in the initial config will be a no-op. + case *v3rbacpb.Principal_Metadata: + // Not supported in gRPC RBAC currently - a principal typed as + // Metadata in the initial config will be a no-op. + } + } + return matchers, nil +} + +// orMatcher is a matcher that successfully matches if at least one of its +// children successfully matches. It also logically represents a principal or +// permission, but can also be its own entity further down the tree of +// matchers. orMatcher implements the matcher interface. +type orMatcher struct { + matchers []matcher +} + +func (om *orMatcher) match(data *rpcData) bool { + // Range through the child matchers, passing in the data about the incoming + // RPC; only one child matcher has to match for the orMatcher to succeed. + for _, m := range om.matchers { + if m.match(data) { + return true + } + } + return false +} + +// andMatcher is a matcher that is successful if every child matcher +// matches. andMatcher implements the matcher interface. +type andMatcher struct { + matchers []matcher +} + +func (am *andMatcher) match(data *rpcData) bool { + for _, m := range am.matchers { + if !m.match(data) { + return false + } + } + return true +} + +// alwaysMatcher is a matcher that will always match. This logically +// represents an any rule for a permission or a principal. alwaysMatcher +// implements the matcher interface. +type alwaysMatcher struct { +} + +func (am *alwaysMatcher) match(data *rpcData) bool { + return true +} + +// notMatcher is a matcher that negates an underlying matcher. notMatcher +// implements the matcher interface. +type notMatcher struct { + matcherToNot matcher +} + +func (nm *notMatcher) match(data *rpcData) bool { + return !nm.matcherToNot.match(data) +} + +// headerMatcher is a matcher that matches on incoming HTTP Headers present +// in the incoming RPC. headerMatcher implements the matcher interface.
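+//
+// For example (editor's note), an xDS header matcher of the form
+//
+//	header:{name:":method" exact_match:"GET" invert_match:true}
+//
+// is turned by newHeaderMatcher below into a HeaderExactMatcher wrapped in an
+// InvertMatcher, matching every RPC whose :method header is not "GET".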
+type headerMatcher struct { + matcher internalmatcher.HeaderMatcher +} + +func newHeaderMatcher(headerMatcherConfig *v3route_componentspb.HeaderMatcher) (*headerMatcher, error) { + var m internalmatcher.HeaderMatcher + switch headerMatcherConfig.HeaderMatchSpecifier.(type) { + case *v3route_componentspb.HeaderMatcher_ExactMatch: + m = internalmatcher.NewHeaderExactMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetExactMatch()) + case *v3route_componentspb.HeaderMatcher_SafeRegexMatch: + regex, err := regexp.Compile(headerMatcherConfig.GetSafeRegexMatch().Regex) + if err != nil { + return nil, err + } + m = internalmatcher.NewHeaderRegexMatcher(headerMatcherConfig.Name, regex) + case *v3route_componentspb.HeaderMatcher_RangeMatch: + m = internalmatcher.NewHeaderRangeMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetRangeMatch().Start, headerMatcherConfig.GetRangeMatch().End) + case *v3route_componentspb.HeaderMatcher_PresentMatch: + m = internalmatcher.NewHeaderPresentMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetPresentMatch()) + case *v3route_componentspb.HeaderMatcher_PrefixMatch: + m = internalmatcher.NewHeaderPrefixMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetPrefixMatch()) + case *v3route_componentspb.HeaderMatcher_SuffixMatch: + m = internalmatcher.NewHeaderSuffixMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetSuffixMatch()) + case *v3route_componentspb.HeaderMatcher_ContainsMatch: + m = internalmatcher.NewHeaderContainsMatcher(headerMatcherConfig.Name, headerMatcherConfig.GetContainsMatch()) + default: + return nil, errors.New("unknown header matcher type") + } + if headerMatcherConfig.InvertMatch { + m = internalmatcher.NewInvertMatcher(m) + } + return &headerMatcher{matcher: m}, nil +} + +func (hm *headerMatcher) match(data *rpcData) bool { + return hm.matcher.Match(data.md) +} + +// urlPathMatcher matches on the URL Path of the incoming RPC. In gRPC, this +// logically maps to the full method name the RPC is calling on the server side. +// urlPathMatcher implements the matcher interface. +type urlPathMatcher struct { + stringMatcher internalmatcher.StringMatcher +} + +func newURLPathMatcher(pathMatcher *v3matcherpb.PathMatcher) (*urlPathMatcher, error) { + stringMatcher, err := internalmatcher.StringMatcherFromProto(pathMatcher.GetPath()) + if err != nil { + return nil, err + } + return &urlPathMatcher{stringMatcher: stringMatcher}, nil +} + +func (upm *urlPathMatcher) match(data *rpcData) bool { + return upm.stringMatcher.Match(data.fullMethod) +} + +// sourceIPMatcher and destinationIPMatcher are both matchers that match +// against a CIDR range. Two different matchers are needed because the source +// and destination addresses come from different parts of the data about the +// incoming RPC. Matching a CIDR range means determining whether the IP address +// falls within the range or not. They both implement the matcher interface. +type sourceIPMatcher struct { + // ipNet represents the CidrRange that this matcher was configured with. + // This is what source and destination IPs will be matched against. + ipNet *net.IPNet +} + +func newSourceIPMatcher(cidrRange *v3corepb.CidrRange) (*sourceIPMatcher, error) { + // Convert the configuration into a CIDR string, as the Go standard library + // has methods that parse a CIDR string.
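+ // For example (editor's note), AddressPrefix "10.0.0.0" with PrefixLen 8
+ // becomes the string "10.0.0.0/8", which net.ParseCIDR converts into the
+ // *net.IPNet used for the Contains checks below.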
+ cidrRangeString := fmt.Sprintf("%s/%d", cidrRange.AddressPrefix, cidrRange.PrefixLen.Value) + _, ipNet, err := net.ParseCIDR(cidrRangeString) + if err != nil { + return nil, err + } + return &sourceIPMatcher{ipNet: ipNet}, nil +} + +func (sim *sourceIPMatcher) match(data *rpcData) bool { + return sim.ipNet.Contains(net.IP(net.ParseIP(data.peerInfo.Addr.String()))) +} + +type destinationIPMatcher struct { + ipNet *net.IPNet +} + +func newDestinationIPMatcher(cidrRange *v3corepb.CidrRange) (*destinationIPMatcher, error) { + cidrRangeString := fmt.Sprintf("%s/%d", cidrRange.AddressPrefix, cidrRange.PrefixLen.Value) + _, ipNet, err := net.ParseCIDR(cidrRangeString) + if err != nil { + return nil, err + } + return &destinationIPMatcher{ipNet: ipNet}, nil +} + +func (dim *destinationIPMatcher) match(data *rpcData) bool { + return dim.ipNet.Contains(net.IP(net.ParseIP(data.destinationAddr.String()))) +} + +// portMatcher matches on whether the destination port of the RPC matches the +// destination port this matcher was instantiated with. portMatcher +// implements the matcher interface. +type portMatcher struct { + destinationPort uint32 +} + +func newPortMatcher(destinationPort uint32) *portMatcher { + return &portMatcher{destinationPort: destinationPort} +} + +func (pm *portMatcher) match(data *rpcData) bool { + return data.destinationPort == pm.destinationPort +} + +// authenticatedMatcher matches on the name of the Principal. If set, the URI +// SAN or DNS SAN in that order is used from the certificate, otherwise the +// subject field is used. If unset, it applies to any user that is +// authenticated. authenticatedMatcher implements the matcher interface. +type authenticatedMatcher struct { + stringMatcher *internalmatcher.StringMatcher +} + +func newAuthenticatedMatcher(authenticatedMatcherConfig *v3rbacpb.Principal_Authenticated) (*authenticatedMatcher, error) { + // Represents this line in the RBAC documentation = "If unset, it applies to + // any user that is authenticated" (see package-level comments). + if authenticatedMatcherConfig.PrincipalName == nil { + return &authenticatedMatcher{}, nil + } + stringMatcher, err := internalmatcher.StringMatcherFromProto(authenticatedMatcherConfig.PrincipalName) + if err != nil { + return nil, err + } + return &authenticatedMatcher{stringMatcher: &stringMatcher}, nil +} + +func (am *authenticatedMatcher) match(data *rpcData) bool { + // Represents this line in the RBAC documentation = "If unset, it applies to + // any user that is authenticated" (see package-level comments). An + // authenticated downstream in a stateful TLS connection will have to + // provide a certificate to prove their identity. Thus, you can simply check + // if there is a certificate present. + if am.stringMatcher == nil { + return len(data.certs) != 0 + } + // No certificate present, so will never successfully match. + if len(data.certs) == 0 { + return false + } + cert := data.certs[0] + // The order of matching as per the RBAC documentation (see package-level comments) + // is as follows: URI SANs, DNS SANs, and then subject name. 
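+ // For example (editor's note), with a matcher for
+ // "spiffe://cluster.local/ns/default/sa/admin", a certificate carrying
+ // that URI SAN matches in the first loop below, and its DNS SANs and
+ // subject are never consulted.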
+ for _, uriSAN := range cert.URIs { + if am.stringMatcher.Match(uriSAN.String()) { + return true + } + } + for _, dnsSAN := range cert.DNSNames { + if am.stringMatcher.Match(dnsSAN) { + return true + } + } + return am.stringMatcher.Match(cert.Subject.String()) +} diff --git a/internal/xds/rbac/rbac_engine.go b/internal/xds/rbac/rbac_engine.go new file mode 100644 index 000000000000..609d123c8039 --- /dev/null +++ b/internal/xds/rbac/rbac_engine.go @@ -0,0 +1,221 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Package rbac provides service-level and method-level access control for a +// service. See +// https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/rbac/v3/rbac.proto#role-based-access-control-rbac +// for documentation. +package rbac + +import ( + "context" + "crypto/x509" + "errors" + "fmt" + "net" + "strconv" + + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +// ChainEngine represents a chain of RBAC Engines, used to make authorization +// decisions on incoming RPCs. +type ChainEngine struct { + chainedEngines []*engine +} + +// NewChainEngine returns a chain of RBAC engines, used to make authorization +// decisions on incoming RPCs. Returns a non-nil error for invalid policies. +func NewChainEngine(policies []*v3rbacpb.RBAC) (*ChainEngine, error) { + var engines []*engine + for _, policy := range policies { + engine, err := newEngine(policy) + if err != nil { + return nil, err + } + engines = append(engines, engine) + } + return &ChainEngine{chainedEngines: engines}, nil +} + +// IsAuthorized determines if an incoming RPC is authorized based on the chain of RBAC +// engines and their associated actions. +// +// Errors returned by this function are compatible with the status package. +func (cre *ChainEngine) IsAuthorized(ctx context.Context) error { + // This conversion step (i.e. pulling things out of ctx) can be done once, + // and then be used for the whole chain of RBAC Engines. + rpcData, err := newRPCData(ctx) + if err != nil { + return status.Errorf(codes.InvalidArgument, "missing fields in ctx %+v: %v", ctx, err) + } + for _, engine := range cre.chainedEngines { + matchingPolicyName, ok := engine.findMatchingPolicy(rpcData) + + switch { + case engine.action == v3rbacpb.RBAC_ALLOW && !ok: + return status.Errorf(codes.PermissionDenied, "incoming RPC did not match an allow policy") + case engine.action == v3rbacpb.RBAC_DENY && ok: + return status.Errorf(codes.PermissionDenied, "incoming RPC matched a deny policy %q", matchingPolicyName) + } + // Every policy in the engine list must be queried. Thus, iterate to the + // next policy. + } + // If the incoming RPC gets through all of the engines successfully (i.e. 
+ // matches a policy in every allow engine and no policy in any deny + // engine), the RPC is authorized to proceed. + return status.Error(codes.OK, "") +} + +// engine is used for matching incoming RPCs to policies. +type engine struct { + policies map[string]*policyMatcher + // action must be ALLOW or DENY. + action v3rbacpb.RBAC_Action +} + +// newEngine creates an RBAC Engine based on the contents of config. Returns a +// non-nil error if the policy is invalid. +func newEngine(config *v3rbacpb.RBAC) (*engine, error) { + a := *config.Action.Enum() + if a != v3rbacpb.RBAC_ALLOW && a != v3rbacpb.RBAC_DENY { + return nil, fmt.Errorf("unsupported action %s", config.Action) + } + + policies := make(map[string]*policyMatcher, len(config.Policies)) + for name, policy := range config.Policies { + matcher, err := newPolicyMatcher(policy) + if err != nil { + return nil, err + } + policies[name] = matcher + } + return &engine{ + policies: policies, + action: a, + }, nil +} + +// findMatchingPolicy determines if an incoming RPC matches a policy. On a +// successful match, it returns the name of the matching policy and true. If +// no policy matches, it returns false. +func (r *engine) findMatchingPolicy(rpcData *rpcData) (string, bool) { + for policy, matcher := range r.policies { + if matcher.match(rpcData) { + return policy, true + } + } + return "", false +} + +// newRPCData takes an incoming context (which should carry the metadata, peer +// info (used for source ip/port and TLS information), and connection (used for +// destination ip/port) of a server RPC call) and populates an rpcData struct, +// including the full method name of the service being called, ready to be +// passed to the RBAC Engine to find a matching policy. +func newRPCData(ctx context.Context) (*rpcData, error) { + // The caller should populate all of these fields (i.e. for empty headers, + // pipe an empty md into context). + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return nil, errors.New("missing metadata in incoming context") + } + + pi, ok := peer.FromContext(ctx) + if !ok { + return nil, errors.New("missing peer info in incoming context") + } + + // The method name is available in the passed-in ctx from both unary and + // streaming interceptors, because grpc.Server pipes a transport stream + // carrying the method name into their contexts. + mn, ok := grpc.Method(ctx) + if !ok { + return nil, errors.New("missing method in incoming context") + } + + // The connection is needed in order to find the destination address and + // port of the incoming RPC Call.
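+ // (editor's note: the server owner is expected to have attached the
+ // net.Conn via SetConnection, defined at the bottom of this file;
+ // otherwise getConnection returns nil and newRPCData fails.)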
+ conn := getConnection(ctx) + if conn == nil { + return nil, errors.New("missing connection in incoming context") + } + _, dPort, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + return nil, fmt.Errorf("error parsing local address: %v", err) + } + dp, err := strconv.ParseUint(dPort, 10, 32) + if err != nil { + return nil, fmt.Errorf("error parsing local address: %v", err) + } + + var peerCertificates []*x509.Certificate + if pi.AuthInfo != nil { + tlsInfo, ok := pi.AuthInfo.(credentials.TLSInfo) + if ok { + peerCertificates = tlsInfo.State.PeerCertificates + } + } + + return &rpcData{ + md: md, + peerInfo: pi, + fullMethod: mn, + destinationPort: uint32(dp), + destinationAddr: conn.LocalAddr(), + certs: peerCertificates, + }, nil +} + +// rpcData wraps data pulled from an incoming RPC that the RBAC engine needs to +// find a matching policy. +type rpcData struct { + // md is the HTTP Headers that are present in the incoming RPC. + md metadata.MD + // peerInfo is information about the downstream peer. + peerInfo *peer.Peer + // fullMethod is the method name being called on the upstream service. + fullMethod string + // destinationPort is the port that the RPC is being sent to on the + // server. + destinationPort uint32 + // destinationAddr is the address that the RPC is being sent to. + destinationAddr net.Addr + // certs are the certificates presented by the peer during a TLS + // handshake. + certs []*x509.Certificate +} + +type connectionKey struct{} + +func getConnection(ctx context.Context) net.Conn { + conn, _ := ctx.Value(connectionKey{}).(net.Conn) + return conn +} + +// SetConnection adds the connection to the context to be able to get +// information about the destination ip and port for an incoming RPC. +func SetConnection(ctx context.Context, conn net.Conn) context.Context { + return context.WithValue(ctx, connectionKey{}, conn) +} diff --git a/internal/xds/rbac/rbac_engine_test.go b/internal/xds/rbac/rbac_engine_test.go new file mode 100644 index 000000000000..2521ac4526aa --- /dev/null +++ b/internal/xds/rbac/rbac_engine_test.go @@ -0,0 +1,920 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package rbac + +import ( + "context" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "net" + "net/url" + "testing" + + v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + v3rbacpb "github.com/envoyproxy/go-control-plane/envoy/config/rbac/v3" + v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + v3matcherpb "github.com/envoyproxy/go-control-plane/envoy/type/matcher/v3" + v3typepb "github.com/envoyproxy/go-control-plane/envoy/type/v3" + wrapperspb "github.com/golang/protobuf/ptypes/wrappers" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type addr struct { + ipAddress string +} + +func (addr) Network() string { return "" } +func (a *addr) String() string { return a.ipAddress } + +// TestNewChainEngine tests the construction of the ChainEngine. Due to some +// types of RBAC configuration being logically wrong and returning an error +// rather than successfully constructing the RBAC Engine, this test tests both +// RBAC Configurations deemed successful and also RBAC Configurations that will +// raise errors. +func (s) TestNewChainEngine(t *testing.T) { + tests := []struct { + name string + policies []*v3rbacpb.RBAC + wantErr bool + }{ + { + name: "SuccessCaseAnyMatchSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseAnyMatchMultiple", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseSimplePolicySingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + // SuccessCaseSimplePolicyTwoPolicies tests the construction of the + // chained engines in the case where there are two policies in a list, + // one with an allow policy and one with a deny policy. A situation + // where two policies (allow and deny) is a very common use case for + // this API, and should successfully build. 
+ { + name: "SuccessCaseSimplePolicyTwoPolicies", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + { + Action: v3rbacpb.RBAC_DENY, + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 8080}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SuccessCaseEnvoyExampleSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "service-admin": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "cluster.local/ns/default/sa/admin"}}}}}, + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "cluster.local/ns/default/sa/superuser"}}}}}, + }, + }, + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/products"}}}}}}, + {Rule: &v3rbacpb.Permission_OrRules{OrRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 80}}, + {Rule: &v3rbacpb.Permission_DestinationPort{DestinationPort: 443}}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "SourceIpMatcherSuccessSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + }, + { + name: "SourceIpMatcherFailureSingular", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + 
Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "not a correct address", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "DestinationIpMatcherSuccess", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "DestinationIpMatcherFailure", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "not a correct address", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "MatcherToNotPolicy", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "not-secret-content": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_NotRule{NotRule: &v3rbacpb.Permission{Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/secret-content"}}}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + { + name: "MatcherToNotPrincipal", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "not-from-certain-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_NotId{NotId: &v3rbacpb.Principal{Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}}}, + }, + }, + }, + }, + }, + }, + // PrincipalProductViewer tests the construction of a chained engine + // with a policy that allows any downstream to send a GET request on a + // certain path.
+ { + name: "PrincipalProductViewer", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_ALLOW, + Policies: map[string]*v3rbacpb.Policy{ + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + { + Identifier: &v3rbacpb.Principal_AndIds{AndIds: &v3rbacpb.Principal_Set{Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{ + Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/books"}}}}}}, + {Identifier: &v3rbacpb.Principal_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/cars"}}}}}}, + }, + }}}, + }}}, + }, + }, + }, + }, + }, + }, + }, + // Certain Headers tests the construction of a chained engine with a + // policy that allows any downstream to send an HTTP request with + // certain headers. + { + name: "CertainHeaders", + policies: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-headers": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + { + Identifier: &v3rbacpb.Principal_OrIds{OrIds: &v3rbacpb.Principal_Set{Ids: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SafeRegexMatch{SafeRegexMatch: &v3matcherpb.RegexMatcher{Regex: "GET"}}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_RangeMatch{RangeMatch: &v3typepb.Int64Range{ + Start: 0, + End: 64, + }}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PresentMatch{PresentMatch: true}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{PrefixMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_SuffixMatch{SuffixMatch: "GET"}}}}, + {Identifier: &v3rbacpb.Principal_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ContainsMatch{ContainsMatch: "GET"}}}}, + }}}, + }, + }, + }, + }, + }, + }, + }, + { + name: "LogAction", + policies: []*v3rbacpb.RBAC{ + { + Action: v3rbacpb.RBAC_LOG, + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + wantErr: true, + }, + { + name: "ActionNotSpecified", + policies: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: 
[]*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if _, err := NewChainEngine(test.policies); (err != nil) != test.wantErr { + t.Fatalf("NewChainEngine(%+v) returned err: %v, wantErr: %v", test.policies, err, test.wantErr) + } + }) + } +} + +// TestChainEngine tests the chain of RBAC Engines by configuring the chain in +// different ways, then pinging it with different types of data representing +// incoming RPCs (piped into a context), and verifying that it works as +// expected. +func (s) TestChainEngine(t *testing.T) { + tests := []struct { + name string + rbacConfigs []*v3rbacpb.RBAC + rbacQueries []struct { + rpcData *rpcData + wantStatusCode codes.Code + } + }{ + // SuccessCaseAnyMatch tests a single RBAC Engine instantiated with + // a config with a policy with any rules for both permissions and + // principals, meaning that any data about incoming RPCs that the RBAC + // Engine is queried with should match that policy. + { + name: "SuccessCaseAnyMatch", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "anyone": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + { + rpcData: &rpcData{ + fullMethod: "some method", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + }, + }, + // SuccessCaseSimplePolicy tests a single policy that only allows an + // RPC to proceed if the RPC calls a certain path. + { + name: "SuccessCaseSimplePolicy", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This RPC should match with the local host fan policy. Thus, + // this RPC should be allowed to proceed. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + + // This RPC shouldn't match with the local host fan policy. Thus, + // this RPC shouldn't be allowed to proceed. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // SuccessCaseEnvoyExample is a test based on the example provided + // in the EnvoyProxy docs. The RBAC Config contains two policies, + // service admin and product viewer, providing an example of a real + // RBAC Config that might be configured for a given backend + // service.
+ { + name: "SuccessCaseEnvoyExample", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "service-admin": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "//cluster.local/ns/default/sa/admin"}}}}}, + {Identifier: &v3rbacpb.Principal_Authenticated_{Authenticated: &v3rbacpb.Principal_Authenticated{PrincipalName: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "//cluster.local/ns/default/sa/superuser"}}}}}, + }, + }, + "product-viewer": { + Permissions: []*v3rbacpb.Permission{ + { + Rule: &v3rbacpb.Permission_AndRules{AndRules: &v3rbacpb.Permission_Set{ + Rules: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Header{Header: &v3routepb.HeaderMatcher{Name: ":method", HeaderMatchSpecifier: &v3routepb.HeaderMatcher_ExactMatch{ExactMatch: "GET"}}}}, + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/products"}}}}}}, + }, + }, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the service admin + // policy. + { + rpcData: &rpcData{ + fullMethod: "some method", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + { + URIs: []*url.URL{ + { + Host: "cluster.local", + Path: "/ns/default/sa/admin", + }, + }, + }, + }, + }, + }, + }, + }, + wantStatusCode: codes.OK, + }, + // These incoming RPC calls should not match any policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + { + rpcData: &rpcData{ + fullMethod: "get-product-list", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + AuthInfo: credentials.TLSInfo{ + State: tls.ConnectionState{ + PeerCertificates: []*x509.Certificate{ + { + Subject: pkix.Name{ + CommonName: "localhost", + }, + }, + }, + }, + }, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "NotMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "not-secret-content": { + Permissions: []*v3rbacpb.Permission{ + { + Rule: &v3rbacpb.Permission_NotRule{ + NotRule: &v3rbacpb.Permission{Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Prefix{Prefix: "/secret-content"}}}}}}, + }, + }, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the not-secret-content policy. 
+ { + rpcData: &rpcData{ + fullMethod: "/regular-content", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This incoming RPC Call shouldn't match with the not-secret-content-policy. + { + rpcData: &rpcData{ + fullMethod: "/secret-content", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "SourceIpMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call should match with the certain-source-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This incoming RPC Call shouldn't match with the certain-source-ip policy. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + { + name: "DestinationIpMatcher", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-destination-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_DestinationIp{DestinationIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This incoming RPC Call shouldn't match with the + // certain-destination-ip policy, as the test listens on local + // host. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + // AllowAndDenyPolicy tests a policy with an allow (on path) and + // deny (on port) policy chained together. This represents how a user + // configured interceptor would use this, and also is a potential + // configuration for a dynamic xds interceptor. 
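+ // (editor's note: IsAuthorized walks the engines in order, so with this
+ // chain an RPC must first match a policy in the ALLOW engine and must
+ // then match no policy in the DENY engine to receive codes.OK.)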
+ { + name: "AllowAndDenyPolicy", + rbacConfigs: []*v3rbacpb.RBAC{ + { + Policies: map[string]*v3rbacpb.Policy{ + "certain-source-ip": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_Any{Any: true}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_DirectRemoteIp{DirectRemoteIp: &v3corepb.CidrRange{AddressPrefix: "0.0.0.0", PrefixLen: &wrapperspb.UInt32Value{Value: uint32(10)}}}}, + }, + }, + }, + Action: v3rbacpb.RBAC_ALLOW, + }, + { + Policies: map[string]*v3rbacpb.Policy{ + "localhost-fan": { + Permissions: []*v3rbacpb.Permission{ + {Rule: &v3rbacpb.Permission_UrlPath{UrlPath: &v3matcherpb.PathMatcher{Rule: &v3matcherpb.PathMatcher_Path{Path: &v3matcherpb.StringMatcher{MatchPattern: &v3matcherpb.StringMatcher_Exact{Exact: "localhost-fan-page"}}}}}}, + }, + Principals: []*v3rbacpb.Principal{ + {Identifier: &v3rbacpb.Principal_Any{Any: true}}, + }, + }, + }, + Action: v3rbacpb.RBAC_DENY, + }, + }, + rbacQueries: []struct { + rpcData *rpcData + wantStatusCode codes.Code + }{ + // This RPC should match with the allow policy, and shouldn't + // match with the deny and thus should be allowed to proceed. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.OK, + }, + // This RPC should match with both the allow policy and deny policy + // and thus shouldn't be allowed to proceed as matched with deny. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "0.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + // This RPC shouldn't match with either policy, and thus + // shouldn't be allowed to proceed as didn't match with allow. + { + rpcData: &rpcData{ + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + // This RPC shouldn't match with allow, match with deny, and + // thus shouldn't be allowed to proceed. + { + rpcData: &rpcData{ + fullMethod: "localhost-fan-page", + peerInfo: &peer.Peer{ + Addr: &addr{ipAddress: "10.0.0.0"}, + }, + }, + wantStatusCode: codes.PermissionDenied, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Instantiate the chainedRBACEngine with different configurations that are + // interesting to test and to query. + cre, err := NewChainEngine(test.rbacConfigs) + if err != nil { + t.Fatalf("Error constructing RBAC Engine: %v", err) + } + // Query the created chain of RBAC Engines with different args to see + // if the chain of RBAC Engines configured as such works as intended. + for _, data := range test.rbacQueries { + func() { + // Construct the context with three data points that have enough + // information to represent incoming RPC's. This will be how a + // user uses this API. A user will have to put MD, PeerInfo, and + // the connection the RPC is sent on in the context. + ctx := metadata.NewIncomingContext(context.Background(), data.rpcData.md) + + // Make a TCP connection with a certain destination port. The + // address/port of this connection will be used to populate the + // destination ip/port in RPCData struct. This represents what + // the user of ChainEngine will have to place into + // context, as this is only way to get destination ip and port. 
+					lis, err := net.Listen("tcp", "localhost:0")
+					if err != nil {
+						t.Fatalf("Error listening: %v", err)
+					}
+					defer lis.Close()
+					connCh := make(chan net.Conn, 1)
+					go func() {
+						conn, err := lis.Accept()
+						if err != nil {
+							t.Errorf("Error accepting connection: %v", err)
+							return
+						}
+						connCh <- conn
+					}()
+					_, err = net.Dial("tcp", lis.Addr().String())
+					if err != nil {
+						t.Fatalf("Error dialing: %v", err)
+					}
+					conn := <-connCh
+					defer conn.Close()
+					ctx = SetConnection(ctx, conn)
+					ctx = peer.NewContext(ctx, data.rpcData.peerInfo)
+					stream := &ServerTransportStreamWithMethod{
+						method: data.rpcData.fullMethod,
+					}
+
+					ctx = grpc.NewContextWithServerTransportStream(ctx, stream)
+					err = cre.IsAuthorized(ctx)
+					if gotCode := status.Code(err); gotCode != data.wantStatusCode {
+						t.Fatalf("IsAuthorized(%+v, %+v) returned (%+v), want (%+v)", ctx, data.rpcData.fullMethod, gotCode, data.wantStatusCode)
+					}
+				}()
+			}
+		})
+	}
+}
+
+// ServerTransportStreamWithMethod is a minimal grpc.ServerTransportStream
+// that only carries a method name, for use in tests.
+type ServerTransportStreamWithMethod struct {
+	method string
+}
+
+func (sts *ServerTransportStreamWithMethod) Method() string {
+	return sts.method
+}
+
+func (sts *ServerTransportStreamWithMethod) SetHeader(md metadata.MD) error {
+	return nil
+}
+
+func (sts *ServerTransportStreamWithMethod) SendHeader(md metadata.MD) error {
+	return nil
+}
+
+func (sts *ServerTransportStreamWithMethod) SetTrailer(md metadata.MD) error {
+	return nil
+}
diff --git a/metadata/metadata.go b/metadata/metadata.go
index e4cbea917498..3604c7819fdc 100644
--- a/metadata/metadata.go
+++ b/metadata/metadata.go
@@ -93,12 +93,16 @@ func (md MD) Copy() MD {
}

// Get obtains the values for a given key.
+//
+// k is converted to lowercase before searching in md.
func (md MD) Get(k string) []string {
	k = strings.ToLower(k)
	return md[k]
}

// Set sets the value of a given key with a slice of values.
+//
+// k is converted to lowercase before storing in md.
func (md MD) Set(k string, vals ...string) {
	if len(vals) == 0 {
		return
@@ -107,7 +111,10 @@ func (md MD) Set(k string, vals ...string) {
	md[k] = vals
}

-// Append adds the values to key k, not overwriting what was already stored at that key.
+// Append adds the values to key k, not overwriting what was already stored at
+// that key.
+//
+// k is converted to lowercase before storing in md.
func (md MD) Append(k string, vals ...string) {
	if len(vals) == 0 {
		return
@@ -116,9 +123,17 @@ func (md MD) Append(k string, vals ...string) {
	md[k] = append(md[k], vals...)
}

+// Delete removes the values for a given key k, which is converted to
+// lowercase before it is removed from md.
+func (md MD) Delete(k string) {
+	k = strings.ToLower(k)
+	delete(md, k)
+}
+
// Join joins any number of mds into a single MD.
-// The order of values for each key is determined by the order in which
-// the mds containing those values are presented to Join.
+//
+// The order of values for each key is determined by the order in which the mds
+// containing those values are presented to Join.
func Join(mds ...MD) MD {
	out := MD{}
	for _, md := range mds {
@@ -145,8 +160,8 @@ func NewOutgoingContext(ctx context.Context, md MD) context.Context {
}

// AppendToOutgoingContext returns a new context with the provided kv merged
-// with any existing metadata in the context. Please refer to the
-// documentation of Pairs for a description of kv.
+// with any existing metadata in the context. Please refer to the documentation
+// of Pairs for a description of kv.
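+//
+// An illustrative sketch of typical usage (keys and values hypothetical):
+//
+//	ctx = metadata.AppendToOutgoingContext(ctx, "k1", "v1", "k1", "v2")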
func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context {
	if len(kv)%2 == 1 {
		panic(fmt.Sprintf("metadata: AppendToOutgoingContext got an odd number of input pairs for metadata: %d", len(kv)))
@@ -159,20 +174,34 @@ func AppendToOutgoingContext(ctx context.Context, kv ...string) context.Context
	return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md.md, added: added})
}

-// FromIncomingContext returns the incoming metadata in ctx if it exists. The
-// returned MD should not be modified. Writing to it may cause races.
-// Modification should be made to copies of the returned MD.
-func FromIncomingContext(ctx context.Context) (md MD, ok bool) {
-	md, ok = ctx.Value(mdIncomingKey{}).(MD)
-	return
+// FromIncomingContext returns the incoming metadata in ctx if it exists.
+//
+// All keys in the returned MD are lowercase.
+func FromIncomingContext(ctx context.Context) (MD, bool) {
+	md, ok := ctx.Value(mdIncomingKey{}).(MD)
+	if !ok {
+		return nil, false
+	}
+	out := MD{}
+	for k, v := range md {
+		// We need to manually convert all keys to lower case, because MD is a
+		// map, and there's no guarantee that the MD attached to the context is
+		// created using our helper functions.
+		key := strings.ToLower(k)
+		out[key] = v
+	}
+	return out, true
}

-// FromOutgoingContextRaw returns the un-merged, intermediary contents
-// of rawMD. Remember to perform strings.ToLower on the keys. The returned
-// MD should not be modified. Writing to it may cause races. Modification
-// should be made to copies of the returned MD.
+// FromOutgoingContextRaw returns the un-merged, intermediary contents of rawMD.
+//
+// Remember to perform strings.ToLower on the keys, both for the returned MD
+// (since MD is a map, there is no guarantee it was created using our helper
+// functions) and for the extra kv pairs (AppendToOutgoingContext does not
+// lowercase them).
//
-// This is intended for gRPC-internal use ONLY.
+// This is intended for gRPC-internal use ONLY. Users should use
+// FromOutgoingContext instead.
func FromOutgoingContextRaw(ctx context.Context) (MD, [][]string, bool) {
	raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD)
	if !ok {
@@ -182,16 +211,23 @@ func FromOutgoingContextRaw(ctx context.Context) (MD, [][]string, bool) {
	return raw.md, raw.added, true
}

-// FromOutgoingContext returns the outgoing metadata in ctx if it exists. The
-// returned MD should not be modified. Writing to it may cause races.
-// Modification should be made to copies of the returned MD.
+// FromOutgoingContext returns the outgoing metadata in ctx if it exists.
+//
+// All keys in the returned MD are lowercase.
func FromOutgoingContext(ctx context.Context) (MD, bool) {
	raw, ok := ctx.Value(mdOutgoingKey{}).(rawMD)
	if !ok {
		return nil, false
	}
-	out := raw.md.Copy()
+	out := MD{}
+	for k, v := range raw.md {
+		// We need to manually convert all keys to lower case, because MD is a
+		// map, and there's no guarantee that the MD attached to the context is
+		// created using our helper functions.
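+		// For example, a caller may have stored MD{"Key": []string{"v"}}
+		// directly with NewOutgoingContext (illustrative), leaving an
+		// uppercase key in the map.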
+		key := strings.ToLower(k)
+		out[key] = v
+	}
	for _, added := range raw.added {
		if len(added)%2 == 1 {
			panic(fmt.Sprintf("metadata: FromOutgoingContext got an odd number of input pairs for metadata: %d", len(added)))
diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go
index f1fb5f6d324e..89be06eaada0 100644
--- a/metadata/metadata_test.go
+++ b/metadata/metadata_test.go
@@ -169,6 +169,35 @@ func (s) TestAppend(t *testing.T) {
	}
}

+func (s) TestDelete(t *testing.T) {
+	for _, test := range []struct {
+		md        MD
+		deleteKey string
+		want      MD
+	}{
+		{
+			md:        Pairs("My-Optional-Header", "42"),
+			deleteKey: "My-Optional-Header",
+			want:      Pairs(),
+		},
+		{
+			md:        Pairs("My-Optional-Header", "42"),
+			deleteKey: "Other-Key",
+			want:      Pairs("my-optional-header", "42"),
+		},
+		{
+			md:        Pairs("My-Optional-Header", "42"),
+			deleteKey: "my-OptIoNal-HeAder",
+			want:      Pairs(),
+		},
+	} {
+		test.md.Delete(test.deleteKey)
+		if !reflect.DeepEqual(test.md, test.want) {
+			t.Errorf("value of metadata is %v, want %v", test.md, test.want)
+		}
+	}
+}
+
func (s) TestAppendToOutgoingContext(t *testing.T) {
	// Pre-existing metadata
	tCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
diff --git a/picker_wrapper.go b/picker_wrapper.go
index a58174b6f436..0878ada9dbb2 100644
--- a/picker_wrapper.go
+++ b/picker_wrapper.go
@@ -147,7 +147,7 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
			logger.Error("subconn returned from pick is not *acBalancerWrapper")
			continue
		}
-		if t, ok := acw.getAddrConn().getReadyTransport(); ok {
+		if t := acw.getAddrConn().getReadyTransport(); t != nil {
			if channelz.IsOn() {
				return t, doneChannelzWrapper(acw, pickResult.Done), nil
			}
diff --git a/resolver/manual/manual.go b/resolver/manual/manual.go
index 3679d702ab96..f6e7b5ae3581 100644
--- a/resolver/manual/manual.go
+++ b/resolver/manual/manual.go
@@ -27,7 +27,9 @@ import (
// NewBuilderWithScheme creates a new test resolver builder with the given scheme.
func NewBuilderWithScheme(scheme string) *Resolver {
	return &Resolver{
+		BuildCallback:      func(resolver.Target, resolver.ClientConn, resolver.BuildOptions) {},
		ResolveNowCallback: func(resolver.ResolveNowOptions) {},
+		CloseCallback:      func() {},
		scheme:             scheme,
	}
}
@@ -35,11 +37,17 @@ func NewBuilderWithScheme(scheme string) *Resolver {
// Resolver is also a resolver builder.
// Its Build() function always returns itself.
type Resolver struct {
+	// BuildCallback is called when the Build method is called. Must not be
+	// nil. Must not be changed after the resolver may be built.
+	BuildCallback func(resolver.Target, resolver.ClientConn, resolver.BuildOptions)
	// ResolveNowCallback is called when the ResolveNow method is called on the
	// resolver. Must not be nil. Must not be changed after the resolver may
	// be built.
	ResolveNowCallback func(resolver.ResolveNowOptions)
-	scheme             string
+	// CloseCallback is called when the Close method is called. Must not be
+	// nil. Must not be changed after the resolver may be built.
+	CloseCallback func()
+	scheme        string

	// Fields actually belong to the resolver.
	CC resolver.ClientConn
@@ -54,6 +62,7 @@ func (r *Resolver) InitialState(s resolver.State) {

// Build returns itself for Resolver, because it's both a builder and a resolver.
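+//
+// A sketch of typical test usage (illustrative only; the callbacks and dial
+// target are hypothetical):
+//
+//	r := manual.NewBuilderWithScheme("whatever")
+//	r.BuildCallback = func(resolver.Target, resolver.ClientConn, resolver.BuildOptions) { /* record build */ }
+//	r.CloseCallback = func() { /* record close */ }
+//	cc, err := grpc.Dial(r.Scheme()+":///test.server", grpc.WithInsecure(), grpc.WithResolvers(r))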
func (r *Resolver) Build(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOptions) (resolver.Resolver, error) {
+	r.BuildCallback(target, cc, opts)
	r.CC = cc
	if r.bootstrapState != nil {
		r.UpdateState(*r.bootstrapState)
@@ -72,9 +81,16 @@ func (r *Resolver) ResolveNow(o resolver.ResolveNowOptions) {
}

-// Close is a noop for Resolver.
-func (*Resolver) Close() {}
+// Close calls the CloseCallback.
+func (r *Resolver) Close() {
+	r.CloseCallback()
+}

// UpdateState calls CC.UpdateState.
func (r *Resolver) UpdateState(s resolver.State) {
	r.CC.UpdateState(s)
}
+
+// ReportError calls CC.ReportError.
+func (r *Resolver) ReportError(err error) {
+	r.CC.ReportError(err)
+}
diff --git a/resolver_conn_wrapper.go b/resolver_conn_wrapper.go
index 4118de571ab5..2c47cd54f07c 100644
--- a/resolver_conn_wrapper.go
+++ b/resolver_conn_wrapper.go
@@ -39,6 +39,8 @@ type ccResolverWrapper struct {
	resolver resolver.Resolver
	done     *grpcsync.Event
	curState resolver.State
+
+	incomingMu sync.Mutex // Synchronizes all the incoming calls.
}

// newCCResolverWrapper uses the resolver.Builder to build a Resolver and
@@ -90,6 +92,8 @@ func (ccr *ccResolverWrapper) close() {
}

func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error {
+	ccr.incomingMu.Lock()
+	defer ccr.incomingMu.Unlock()
	if ccr.done.HasFired() {
		return nil
	}
@@ -105,6 +109,8 @@ func (ccr *ccResolverWrapper) UpdateState(s resolver.State) error {
}

func (ccr *ccResolverWrapper) ReportError(err error) {
+	ccr.incomingMu.Lock()
+	defer ccr.incomingMu.Unlock()
	if ccr.done.HasFired() {
		return
	}
@@ -114,6 +120,8 @@ func (ccr *ccResolverWrapper) ReportError(err error) {

// NewAddress is called by the resolver implementation to send addresses to gRPC.
func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {
+	ccr.incomingMu.Lock()
+	defer ccr.incomingMu.Unlock()
	if ccr.done.HasFired() {
		return
	}
@@ -128,6 +136,8 @@ func (ccr *ccResolverWrapper) NewAddress(addrs []resolver.Address) {

// NewServiceConfig is called by the resolver implementation to send service
// configs to gRPC.
func (ccr *ccResolverWrapper) NewServiceConfig(sc string) {
+	ccr.incomingMu.Lock()
+	defer ccr.incomingMu.Unlock()
	if ccr.done.HasFired() {
		return
	}
diff --git a/rpc_util.go b/rpc_util.go
index 6db356fa56a7..87987a2e652f 100644
--- a/rpc_util.go
+++ b/rpc_util.go
@@ -258,7 +258,8 @@ func (o PeerCallOption) after(c *callInfo, attempt *csAttempt) {
}

// WaitForReady configures the action to take when an RPC is attempted on broken
-// connections or unreachable servers. If waitForReady is false, the RPC will fail
+// connections or unreachable servers. If waitForReady is false and the
+// connection is in the TRANSIENT_FAILURE state, the RPC will fail
// immediately. Otherwise, the RPC client will block the call until a
// connection is available (or the call is canceled or times out) and will
// retry the call if it fails due to a transient error. gRPC will not retry if
@@ -828,26 +829,28 @@ func Errorf(c codes.Code, format string, a ...interface{}) error {

// toRPCErr converts an error into an error from the status package.
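+// For example (mirroring the mapping below): context.DeadlineExceeded maps to
+// codes.DeadlineExceeded, io.ErrUnexpectedEOF to codes.Internal, a
+// transport.ConnectionError to codes.Unavailable, and a
+// *transport.NewStreamError is unwrapped and converted recursively; anything
+// unrecognized becomes codes.Unknown.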
func toRPCErr(err error) error {
-	if err == nil || err == io.EOF {
+	switch err {
+	case nil, io.EOF:
		return err
-	}
-	if err == io.ErrUnexpectedEOF {
+	case context.DeadlineExceeded:
+		return status.Error(codes.DeadlineExceeded, err.Error())
+	case context.Canceled:
+		return status.Error(codes.Canceled, err.Error())
+	case io.ErrUnexpectedEOF:
		return status.Error(codes.Internal, err.Error())
	}
-	if _, ok := status.FromError(err); ok {
-		return err
-	}
+
	switch e := err.(type) {
	case transport.ConnectionError:
		return status.Error(codes.Unavailable, e.Desc)
-	default:
-		switch err {
-		case context.DeadlineExceeded:
-			return status.Error(codes.DeadlineExceeded, err.Error())
-		case context.Canceled:
-			return status.Error(codes.Canceled, err.Error())
-		}
+	case *transport.NewStreamError:
+		return toRPCErr(e.Err)
+	}
+
+	if _, ok := status.FromError(err); ok {
+		return err
	}
+
	return status.Error(codes.Unknown, err.Error())
}
diff --git a/security/advancedtls/crl.go b/security/advancedtls/crl.go
new file mode 100644
index 000000000000..5b3f90127ee7
--- /dev/null
+++ b/security/advancedtls/crl.go
@@ -0,0 +1,499 @@
+/*
+ *
+ * Copyright 2021 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package advancedtls
+
+import (
+	"bytes"
+	"crypto/sha1"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/asn1"
+	"encoding/binary"
+	"encoding/hex"
+	"errors"
+	"fmt"
+	"io/ioutil"
+	"path/filepath"
+	"strings"
+	"time"
+
+	"google.golang.org/grpc/grpclog"
+)
+
+var grpclogLogger = grpclog.Component("advancedtls")
+
+// Cache is an interface to cache CRL files.
+// The cache implementation must be concurrency safe.
+// A fixed-size LRU cache from golang-lru is recommended.
+type Cache interface {
+	// Add adds a value to the cache.
+	Add(key, value interface{}) bool
+	// Get looks up a key's value from the cache.
+	Get(key interface{}) (value interface{}, ok bool)
+}
+
+// RevocationConfig contains options for CRL lookup.
+type RevocationConfig struct {
+	// RootDir is the directory to search for CRL files.
+	// Directory format must match OpenSSL X509_LOOKUP_hash_dir(3).
+	RootDir string
+	// AllowUndetermined controls whether certificate chains with
+	// RevocationUndetermined revocation status are allowed to complete.
+	AllowUndetermined bool
+	// Cache will store CRL files if not nil, otherwise files are reloaded for every lookup.
+	Cache Cache
+}
+
+// RevocationStatus is the revocation status for a certificate or chain.
+type RevocationStatus int
+
+const (
+	// RevocationUndetermined means we couldn't find or verify a CRL for the cert.
+	RevocationUndetermined RevocationStatus = iota
+	// RevocationUnrevoked means we found the CRL for the cert and the cert is not revoked.
+	RevocationUnrevoked
+	// RevocationRevoked means we found the CRL and the cert is revoked.
+	RevocationRevoked
+)
+
+func (s RevocationStatus) String() string {
+	return [...]string{"RevocationUndetermined", "RevocationUnrevoked", "RevocationRevoked"}[s]
+}
+
+// certificateListExt contains a pkix.CertificateList and parsed
+// extensions that aren't provided by the golang CRL parser.
+type certificateListExt struct {
+	CertList *pkix.CertificateList
+	// RFC5280, 5.2.1, all conforming CRLs must have an AKID with the ID method.
+	AuthorityKeyID []byte
+}
+
+const tagDirectoryName = 4
+
+var (
+	// RFC5280, 5.2.4 id-ce-deltaCRLIndicator OBJECT IDENTIFIER ::= { id-ce 27 }
+	oidDeltaCRLIndicator = asn1.ObjectIdentifier{2, 5, 29, 27}
+	// RFC5280, 5.2.5 id-ce-issuingDistributionPoint OBJECT IDENTIFIER ::= { id-ce 28 }
+	oidIssuingDistributionPoint = asn1.ObjectIdentifier{2, 5, 29, 28}
+	// RFC5280, 5.3.3 id-ce-certificateIssuer OBJECT IDENTIFIER ::= { id-ce 29 }
+	oidCertificateIssuer = asn1.ObjectIdentifier{2, 5, 29, 29}
+	// RFC5280, 4.2.1.1 id-ce-authorityKeyIdentifier OBJECT IDENTIFIER ::= { id-ce 35 }
+	oidAuthorityKeyIdentifier = asn1.ObjectIdentifier{2, 5, 29, 35}
+)
+
+// x509NameHash implements the OpenSSL X509_NAME_hash function for hashed directory lookups.
+func x509NameHash(r pkix.RDNSequence) string {
+	var canonBytes []byte
+	// First, canonicalize all the strings.
+	for _, rdnSet := range r {
+		for i, rdn := range rdnSet {
+			value, ok := rdn.Value.(string)
+			if !ok {
+				continue
+			}
+			// OpenSSL trims all whitespace, does a tolower, and removes extra spaces between words.
+			// Implemented in x509_name_canon in OpenSSL.
+			canonStr := strings.Join(strings.Fields(
+				strings.TrimSpace(strings.ToLower(value))), " ")
+			// Then it changes everything to UTF8 strings.
+			rdnSet[i].Value = asn1.RawValue{Tag: asn1.TagUTF8String, Bytes: []byte(canonStr)}
+		}
+	}
+
+	// Finally, OpenSSL drops the initial sequence tag,
+	// so we marshal all the RDNs separately instead of as a group.
+	for _, canonRdn := range r {
+		b, err := asn1.Marshal(canonRdn)
+		if err != nil {
+			continue
+		}
+		canonBytes = append(canonBytes, b...)
+	}
+
+	issuerHash := sha1.Sum(canonBytes)
+	// OpenSSL takes the first 4 bytes and encodes them as a little endian
+	// uint32 and then uses the hex to make the file name.
+	// In OpenSSL's C code, this would be:
+	// (((unsigned long)md[0]) | ((unsigned long)md[1] << 8L) |
+	// ((unsigned long)md[2] << 16L) | ((unsigned long)md[3] << 24L)
+	// ) & 0xffffffffL;
+	fileHash := binary.LittleEndian.Uint32(issuerHash[0:4])
+	return fmt.Sprintf("%08x", fileHash)
+}
+
+// CheckRevocation checks the connection for revoked certificates based on RFC5280.
+// This implementation has the following major limitations:
+//   * Indirect CRL files are not supported.
+//   * CRL loading is only supported from directories in the X509_LOOKUP_hash_dir format.
+//   * OnlySomeReasons is not supported.
+//   * Delta CRL files are not supported.
+//   * Certificate CRLDistributionPoints must be URLs, but they are then ignored and converted into a file path.
+//   * CRL checks are done after path building, which goes against RFC4158.
+func CheckRevocation(conn tls.ConnectionState, cfg RevocationConfig) error {
+	return CheckChainRevocation(conn.VerifiedChains, cfg)
+}
+
+// CheckChainRevocation checks the verified certificate chain
+// for revoked certificates based on RFC5280.
+func CheckChainRevocation(verifiedChains [][]*x509.Certificate, cfg RevocationConfig) error {
+	// Iterate the verified chains looking for one that is RevocationUnrevoked.
+	// A single RevocationUnrevoked chain is enough to allow the connection, and a single RevocationRevoked
+	// chain does not mean the connection should fail.
+	count := make(map[RevocationStatus]int)
+	for _, chain := range verifiedChains {
+		switch checkChain(chain, cfg) {
+		case RevocationUnrevoked:
+			// If any chain is RevocationUnrevoked then return no error.
+			return nil
+		case RevocationRevoked:
+			// If this chain is revoked, keep looking for another chain.
+			count[RevocationRevoked]++
+			continue
+		case RevocationUndetermined:
+			if cfg.AllowUndetermined {
+				return nil
+			}
+			count[RevocationUndetermined]++
+			continue
+		}
+	}
+	return fmt.Errorf("no unrevoked chains found: %v", count)
+}
+
+// checkChain will determine and check all certificates in chain against the CRL
+// defined in the certificate with the following rules:
+// 1. If any certificate is RevocationRevoked, return RevocationRevoked.
+// 2. If any certificate is RevocationUndetermined, return RevocationUndetermined.
+// 3. If all certificates are RevocationUnrevoked, return RevocationUnrevoked.
+func checkChain(chain []*x509.Certificate, cfg RevocationConfig) RevocationStatus {
+	chainStatus := RevocationUnrevoked
+	for _, c := range chain {
+		switch checkCert(c, chain, cfg) {
+		case RevocationRevoked:
+			// Easy case, if a cert in the chain is revoked, the chain is revoked.
+			return RevocationRevoked
+		case RevocationUndetermined:
+			// If we couldn't find the revocation status for a cert, the chain
+			// is at best RevocationUndetermined; keep looking to see if we
+			// find a cert in the chain that's RevocationRevoked, but return
+			// RevocationUndetermined at a minimum.
+			chainStatus = RevocationUndetermined
+		case RevocationUnrevoked:
+			// Continue iterating up the cert chain.
+			continue
+		}
+	}
+	return chainStatus
+}
+
+func cachedCrl(rawIssuer []byte, cache Cache) (*certificateListExt, bool) {
+	val, ok := cache.Get(hex.EncodeToString(rawIssuer))
+	if !ok {
+		return nil, false
+	}
+	crl, ok := val.(*certificateListExt)
+	if !ok {
+		return nil, false
+	}
+	// If the CRL is expired, force a reload.
+	if crl.CertList.HasExpired(time.Now()) {
+		return nil, false
+	}
+	return crl, true
+}
+
+// fetchIssuerCRL fetches and verifies the CRL for rawIssuer from disk or cache if configured in cfg.
+func fetchIssuerCRL(crlDistributionPoint string, rawIssuer []byte, crlVerifyCrt []*x509.Certificate, cfg RevocationConfig) (*certificateListExt, error) {
+	if cfg.Cache != nil {
+		if crl, ok := cachedCrl(rawIssuer, cfg.Cache); ok {
+			return crl, nil
+		}
+	}
+
+	crl, err := fetchCRL(crlDistributionPoint, rawIssuer, cfg)
+	if err != nil {
+		return nil, fmt.Errorf("fetchCRL(%v) failed err = %v", crlDistributionPoint, err)
+	}
+
+	if err := verifyCRL(crl, rawIssuer, crlVerifyCrt); err != nil {
+		return nil, fmt.Errorf("verifyCRL(%v) failed err = %v", crlDistributionPoint, err)
+	}
+	if cfg.Cache != nil {
+		cfg.Cache.Add(hex.EncodeToString(rawIssuer), crl)
+	}
+	return crl, nil
+}
+
+// checkCert checks a single certificate against the CRL defined in the certificate.
+// It will fetch and verify the CRL(s) defined by CRLDistributionPoints.
+// If we can't load any authoritative CRL files, the status is RevocationUndetermined.
+// c is the certificate to check.
+// crlVerifyCrt is the group of possible certificates to verify the crl.
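+// cfg controls where CRL files are loaded from and whether they are cached.
+//
+// An illustrative call (variable names hypothetical):
+//
+//	status := checkCert(leaf, verifiedChain, RevocationConfig{RootDir: "crl_dir"})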
+func checkCert(c *x509.Certificate, crlVerifyCrt []*x509.Certificate, cfg RevocationConfig) RevocationStatus {
+	if len(c.CRLDistributionPoints) == 0 {
+		return RevocationUnrevoked
+	}
+	// Iterate through CRL distribution points to check for status.
+	for _, dp := range c.CRLDistributionPoints {
+		crl, err := fetchIssuerCRL(dp, c.RawIssuer, crlVerifyCrt, cfg)
+		if err != nil {
+			grpclogLogger.Warningf("fetchIssuerCRL(%v) err = %v", c.Issuer, err)
+			continue
+		}
+		revocation, err := checkCertRevocation(c, crl)
+		if err != nil {
+			grpclogLogger.Warningf("checkCertRevocation(CRL %v) failed %v", crl.CertList.TBSCertList.Issuer, err)
+			// We couldn't check the CRL file for some reason, so continue
+			// to the next file.
+			continue
+		}
+		// Here we've gotten a CRL that loads and verifies.
+		// We only handle all-reasons CRL files, so this file
+		// is authoritative for the certificate.
+		return revocation
+	}
+	// We couldn't load any CRL files for the certificate, so we don't know if it's RevocationUnrevoked or not.
+	return RevocationUndetermined
+}
+
+func checkCertRevocation(c *x509.Certificate, crl *certificateListExt) (RevocationStatus, error) {
+	// Per section 5.3.3 we prime the certificate issuer with the CRL issuer.
+	// Subsequent entries use the previous entry's issuer.
+	rawEntryIssuer, err := asn1.Marshal(crl.CertList.TBSCertList.Issuer)
+	if err != nil {
+		return RevocationUndetermined, err
+	}
+
+	// Loop through all the revoked certificates.
+	for _, revCert := range crl.CertList.TBSCertList.RevokedCertificates {
+		// 5.3 Loop through CRL entry extensions for needed information.
+		for _, ext := range revCert.Extensions {
+			if oidCertificateIssuer.Equal(ext.Id) {
+				extIssuer, err := parseCertIssuerExt(ext)
+				if err != nil {
+					grpclogLogger.Info(err)
+					if ext.Critical {
+						return RevocationUndetermined, err
+					}
+					// Since this is a non-critical extension, we can skip it even though
+					// there was a parsing failure.
+					continue
+				}
+				rawEntryIssuer = extIssuer
+			} else if ext.Critical {
+				return RevocationUndetermined, fmt.Errorf("checkCertRevocation: unhandled critical extension: %v", ext.Id)
+			}
+		}
+
+		// If the issuer and serial number appear in the CRL, the certificate is revoked.
+		if bytes.Equal(c.RawIssuer, rawEntryIssuer) && c.SerialNumber.Cmp(revCert.SerialNumber) == 0 {
+			// CRL contains the serial, so return revoked.
+			return RevocationRevoked, nil
+		}
+	}
+	// We did not find the serial in the CRL file that was valid for the cert,
+	// so the certificate is not revoked.
+	return RevocationUnrevoked, nil
+}
+
+func parseCertIssuerExt(ext pkix.Extension) ([]byte, error) {
+	// 5.3.3 Certificate Issuer
+	// CertificateIssuer ::= GeneralNames
+	// GeneralNames ::= SEQUENCE SIZE (1..MAX) OF GeneralName
+	var generalNames []asn1.RawValue
+	if rest, err := asn1.Unmarshal(ext.Value, &generalNames); err != nil || len(rest) != 0 {
+		return nil, fmt.Errorf("asn1.Unmarshal failed err = %v", err)
+	}
+
+	for _, generalName := range generalNames {
+		// GeneralName ::= CHOICE {
+		//   otherName                 [0] OtherName,
+		//   rfc822Name                [1] IA5String,
+		//   dNSName                   [2] IA5String,
+		//   x400Address               [3] ORAddress,
+		//   directoryName             [4] Name,
+		//   ediPartyName              [5] EDIPartyName,
+		//   uniformResourceIdentifier [6] IA5String,
+		//   iPAddress                 [7] OCTET STRING,
+		//   registeredID              [8] OBJECT IDENTIFIER }
+		if generalName.Tag == tagDirectoryName {
+			return generalName.Bytes, nil
+		}
+	}
+	// Conforming CRL issuers MUST include in this extension the
+	// distinguished name (DN) from the issuer field of the certificate that
+	// corresponds to this CRL entry.
+	// If we couldn't get a directoryName, we can't reason about this file so cert status is
+	// RevocationUndetermined.
+	return nil, errors.New("no DN found in certificate issuer")
+}
+
+// RFC 5280, 4.2.1.1
+type authKeyID struct {
+	ID []byte `asn1:"optional,tag:0"`
+}
+
+// RFC5280, 5.2.5
+// id-ce-issuingDistributionPoint OBJECT IDENTIFIER ::= { id-ce 28 }
+
+// IssuingDistributionPoint ::= SEQUENCE {
+//   distributionPoint          [0] DistributionPointName OPTIONAL,
+//   onlyContainsUserCerts      [1] BOOLEAN DEFAULT FALSE,
+//   onlyContainsCACerts        [2] BOOLEAN DEFAULT FALSE,
+//   onlySomeReasons            [3] ReasonFlags OPTIONAL,
+//   indirectCRL                [4] BOOLEAN DEFAULT FALSE,
+//   onlyContainsAttributeCerts [5] BOOLEAN DEFAULT FALSE }
+
+// -- at most one of onlyContainsUserCerts, onlyContainsCACerts,
+// -- and onlyContainsAttributeCerts may be set to TRUE.
+type issuingDistributionPoint struct {
+	DistributionPoint          asn1.RawValue  `asn1:"optional,tag:0"`
+	OnlyContainsUserCerts      bool           `asn1:"optional,tag:1"`
+	OnlyContainsCACerts        bool           `asn1:"optional,tag:2"`
+	OnlySomeReasons            asn1.BitString `asn1:"optional,tag:3"`
+	IndirectCRL                bool           `asn1:"optional,tag:4"`
+	OnlyContainsAttributeCerts bool           `asn1:"optional,tag:5"`
+}
+
+// parseCRLExtensions parses the extensions for a CRL
+// and checks that they're supported by the parser.
+func parseCRLExtensions(c *pkix.CertificateList) (*certificateListExt, error) {
+	if c == nil {
+		return nil, errors.New("c is nil")
+	}
+	certList := &certificateListExt{CertList: c}
+
+	for _, ext := range c.TBSCertList.Extensions {
+		switch {
+		case oidDeltaCRLIndicator.Equal(ext.Id):
+			return nil, fmt.Errorf("delta CRLs unsupported")
+
+		case oidAuthorityKeyIdentifier.Equal(ext.Id):
+			var a authKeyID
+			if rest, err := asn1.Unmarshal(ext.Value, &a); err != nil {
+				return nil, fmt.Errorf("asn1.Unmarshal failed. err = %v", err)
+			} else if len(rest) != 0 {
+				return nil, errors.New("trailing data after AKID extension")
+			}
+			certList.AuthorityKeyID = a.ID
+
+		case oidIssuingDistributionPoint.Equal(ext.Id):
+			var dp issuingDistributionPoint
+			if rest, err := asn1.Unmarshal(ext.Value, &dp); err != nil {
+				return nil, fmt.Errorf("asn1.Unmarshal failed. err = %v", err)
+			} else if len(rest) != 0 {
+				return nil, errors.New("trailing data after IssuingDistributionPoint extension")
+			}
+
+			if dp.OnlyContainsUserCerts || dp.OnlyContainsCACerts || dp.OnlyContainsAttributeCerts {
+				return nil, errors.New("CRL only contains some certificate types")
+			}
+			if dp.IndirectCRL {
+				return nil, errors.New("indirect CRLs unsupported")
+			}
+			if dp.OnlySomeReasons.BitLength != 0 {
+				return nil, errors.New("onlySomeReasons unsupported")
+			}
+
+		case ext.Critical:
+			return nil, fmt.Errorf("unsupported critical extension: %v", ext.Id)
+		}
+	}
+
+	if len(certList.AuthorityKeyID) == 0 {
+		return nil, errors.New("authority key identifier extension missing")
+	}
+	return certList, nil
+}
+
+func fetchCRL(loc string, rawIssuer []byte, cfg RevocationConfig) (*certificateListExt, error) {
+	var parsedCRL *certificateListExt
+	// 6.3.3 (a) (1) (ii)
+	// According to X509_LOOKUP_hash_dir the format is issuer_hash.rN where N is an increasing number.
+	// There are no gaps, so we break when we can't find a file.
+	for i := 0; ; i++ {
+		// Unmarshal to RDNSequence according to https://pkg.go.dev/crypto/x509/pkix#Name.
+		var r pkix.RDNSequence
+		rest, err := asn1.Unmarshal(rawIssuer, &r)
+		if len(rest) != 0 || err != nil {
+			return nil, fmt.Errorf("asn1.Unmarshal(Issuer) len(rest) = %v, err = %v", len(rest), err)
+		}
+		crlPath := fmt.Sprintf("%s.r%d", filepath.Join(cfg.RootDir, x509NameHash(r)), i)
+		crlBytes, err := ioutil.ReadFile(crlPath)
+		if err != nil {
+			// Break when we can't read a CRL file.
+			grpclogLogger.Infof("readFile: %v", err)
+			break
+		}
+
+		crl, err := x509.ParseCRL(crlBytes)
+		if err != nil {
+			// Parsing errors for a CRL shouldn't happen so fail.
+			return nil, fmt.Errorf("x509.ParseCRL(%v) failed err = %v", crlPath, err)
+		}
+		var certList *certificateListExt
+		if certList, err = parseCRLExtensions(crl); err != nil {
+			grpclogLogger.Infof("fetchCRL: unsupported crl %v, err = %v", crlPath, err)
+			// Continue to find a supported CRL.
+			continue
+		}
+
+		rawCRLIssuer, err := asn1.Marshal(certList.CertList.TBSCertList.Issuer)
+		if err != nil {
+			return nil, fmt.Errorf("asn1.Marshal(%v) failed err = %v", certList.CertList.TBSCertList.Issuer, err)
+		}
+		// RFC5280, 6.3.3 (b) Verify the issuer and scope of the complete CRL.
+		if bytes.Equal(rawIssuer, rawCRLIssuer) {
+			parsedCRL = certList
+			// Continue to find the highest number in the .rN suffix.
+			continue
+		}
+	}
+
+	if parsedCRL == nil {
+		return nil, fmt.Errorf("fetchCRL: no CRLs found for issuer")
+	}
+	return parsedCRL, nil
+}
+
+func verifyCRL(crl *certificateListExt, rawIssuer []byte, chain []*x509.Certificate) error {
+	// RFC5280, 6.3.3 (f) Obtain and validate the certification path for the issuer of the complete CRL.
+	// We intentionally limit our CRLs to be signed with the same certificate path as the certificate
+	// so we can use the chain from the connection.
+	rawCRLIssuer, err := asn1.Marshal(crl.CertList.TBSCertList.Issuer)
+	if err != nil {
+		return fmt.Errorf("asn1.Marshal(%v) failed err = %v", crl.CertList.TBSCertList.Issuer, err)
+	}
+
+	for _, c := range chain {
+		// Use the key where the subject and KIDs match.
+		// This departs from RFC4158, 3.5.12 which states that KIDs
+		// cannot eliminate certificates, but RFC5280, 5.2.1 states that
+		// "Conforming CRL issuers MUST use the key identifier method, and MUST
+		// include this extension in all CRLs issued."
+		// So, this is much simpler than RFC4158 and should be compatible.
+		if bytes.Equal(c.SubjectKeyId, crl.AuthorityKeyID) && bytes.Equal(c.RawSubject, rawCRLIssuer) {
+			// RFC5280, 6.3.3 (g) Validate signature.
+			return c.CheckCRLSignature(crl.CertList)
+		}
+	}
+	return fmt.Errorf("verifyCRL: No certificates matched CRL issuer (%v)", crl.CertList.TBSCertList.Issuer)
+}
diff --git a/security/advancedtls/crl_test.go b/security/advancedtls/crl_test.go
new file mode 100644
index 000000000000..ec4483304c79
--- /dev/null
+++ b/security/advancedtls/crl_test.go
@@ -0,0 +1,718 @@
+/*
+ *
+ * Copyright 2021 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package advancedtls
+
+import (
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/asn1"
+	"encoding/hex"
+	"encoding/pem"
+	"fmt"
+	"io/ioutil"
+	"math/big"
+	"net"
+	"os"
+	"path"
+	"strings"
+	"testing"
+	"time"
+
+	lru "github.com/hashicorp/golang-lru"
+	"google.golang.org/grpc/security/advancedtls/testdata"
+)
+
+func TestX509NameHash(t *testing.T) {
+	nameTests := []struct {
+		in  pkix.Name
+		out string
+	}{
+		{
+			in: pkix.Name{
+				Country:      []string{"US"},
+				Organization: []string{"Example"},
+			},
+			out: "9cdd41ff",
+		},
+		{
+			in: pkix.Name{
+				Country:      []string{"us"},
+				Organization: []string{"example"},
+			},
+			out: "9cdd41ff",
+		},
+		{
+			in: pkix.Name{
+				Country:      []string{" us"},
+				Organization: []string{"example"},
+			},
+			out: "9cdd41ff",
+		},
+		{
+			in: pkix.Name{
+				Country:      []string{"US"},
+				Province:     []string{"California"},
+				Locality:     []string{"Mountain View"},
+				Organization: []string{"BoringSSL"},
+			},
+			out: "c24414d9",
+		},
+		{
+			in: pkix.Name{
+				Country:      []string{"US"},
+				Province:     []string{"California"},
+				Locality:     []string{"Mountain View"},
+				Organization: []string{"BoringSSL"},
+			},
+			out: "c24414d9",
+		},
+		{
+			in: pkix.Name{
+				SerialNumber: "87f4514475ba0a2b",
+			},
+			out: "9dc713cd",
+		},
+		{
+			in: pkix.Name{
+				Country:            []string{"US"},
+				Province:           []string{"California"},
+				Locality:           []string{"Mountain View"},
+				Organization:       []string{"Google LLC"},
+				OrganizationalUnit: []string{"Production", "campus-sln"},
+				CommonName:         "Root CA (2021-02-02T07:30:36-08:00)",
+			},
+			out: "0b35a562",
+		},
+		{
+			in: pkix.Name{
+				ExtraNames: []pkix.AttributeTypeAndValue{
+					{Type: asn1.ObjectIdentifier{5, 5, 5, 5}, Value: "aaaa"},
+				},
+			},
+			out: "eea339da",
+		},
+	}
+	for _, tt := range nameTests {
+		t.Run(tt.in.String(), func(t *testing.T) {
+			h := x509NameHash(tt.in.ToRDNSequence())
+			if h != tt.out {
+				t.Errorf("x509NameHash(%v): Got %v wanted %v", tt.in, h, tt.out)
+			}
+		})
+	}
+}
+
+func TestUnsupportedCRLs(t *testing.T) {
+	crlBytesSomeReasons := []byte(`-----BEGIN X509 CRL-----
+MIIEeDCCA2ACAQEwDQYJKoZIhvcNAQELBQAwQjELMAkGA1UEBhMCVVMxHjAcBgNV
+BAoTFUdvb2dsZSBUcnVzdCBTZXJ2aWNlczETMBEGA1UEAxMKR1RTIENBIDFPMRcN
+MjEwNDI2MTI1OTQxWhcNMjEwNTA2MTE1OTQwWjCCAn0wIgIRAPOOG3L4VLC7CAAA
+AABxQgEXDTIxMDQxOTEyMTgxOFowIQIQUK0UwBZkVdQIAAAAAHFCBRcNMjEwNDE5
+MTIxODE4WjAhAhBRIXBJaKoQkQgAAAAAcULHFw0yMTA0MjAxMjE4MTdaMCICEQCv
+qQWUq5UxmQgAAAAAcULMFw0yMTA0MjAxMjE4MTdaMCICEQDdv5k1kKwKTQgAAAAA +cUOQFw0yMTA0MjExMjE4MTZaMCICEQDGIEfR8N9sEAgAAAAAcUOWFw0yMTA0MjEx +MjE4MThaMCECEBHgbLXlj5yUCAAAAABxQ/IXDTIxMDQyMTIzMDAyNlowIQIQE1wT +2GGYqKwIAAAAAHFD7xcNMjEwNDIxMjMwMDI5WjAiAhEAo/bSyDjpVtsIAAAAAHFE +txcNMjEwNDIyMjMwMDI3WjAhAhARdCrSrHE0dAgAAAAAcUS/Fw0yMTA0MjIyMzAw +MjhaMCECEHONohfWn3wwCAAAAABxRX8XDTIxMDQyMzIzMDAyOVowIgIRAOYkiUPA +os4vCAAAAABxRYgXDTIxMDQyMzIzMDAyOFowIQIQRNTow5Eg2gEIAAAAAHFGShcN +MjEwNDI0MjMwMDI2WjAhAhBX32dH4/WQ6AgAAAAAcUZNFw0yMTA0MjQyMzAwMjZa +MCICEQDHnUM1vsaP/wgAAAAAcUcQFw0yMTA0MjUyMzAwMjZaMCECEEm5rvmL8sj6 +CAAAAABxRxQXDTIxMDQyNTIzMDAyN1owIQIQW16OQs4YQYkIAAAAAHFIABcNMjEw +NDI2MTI1NDA4WjAhAhAhSohpYsJtDQgAAAAAcUgEFw0yMTA0MjYxMjU0MDlaoGkw +ZzAfBgNVHSMEGDAWgBSY0fhuEOvPm+xgnxiQG6DrfQn9KzALBgNVHRQEBAICBngw +NwYDVR0cAQH/BC0wK6AmoCSGImh0dHA6Ly9jcmwucGtpLmdvb2cvR1RTMU8xY29y +ZS5jcmyBAf8wDQYJKoZIhvcNAQELBQADggEBADPBXbxVxMJ1HC7btXExRUpJHUlU +YbeCZGx6zj5F8pkopbmpV7cpewwhm848Fx4VaFFppZQZd92O08daEC6aEqoug4qF +z6ZrOLzhuKfpW8E93JjgL91v0FYN7iOcT7+ERKCwVEwEkuxszxs7ggW6OJYJNvHh +priIdmcPoiQ3ZrIRH0vE3BfUcNXnKFGATWuDkiRI0I4A5P7NiOf+lAuGZet3/eom +0chgts6sdau10GfeUpHUd4f8e93cS/QeLeG16z7LC8vRLstU3m3vrknpZbdGqSia +97w66mqcnQh9V0swZiEnVLmLufaiuDZJ+6nUzSvLqBlb/ei3T/tKV0BoKJA= +-----END X509 CRL-----`) + + crlBytesIndirect := []byte(`-----BEGIN X509 CRL----- +MIIDGjCCAgICAQEwDQYJKoZIhvcNAQELBQAwdjELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCkNhbGlmb3JuaWExFDASBgNVBAoTC1Rlc3RpbmcgTHRkMSowKAYDVQQLEyFU +ZXN0aW5nIEx0ZCBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkxEDAOBgNVBAMTB1Rlc3Qg +Q0EXDTIxMDExNjAyMjAxNloXDTIxMDEyMDA2MjAxNlowgfIwbAIBAhcNMjEwMTE2 +MDIyMDE2WjBYMAoGA1UdFQQDCgEEMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQG +EwNVU0ExDTALBgNVBAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0 +MTAgAgEDFw0yMTAxMTYwMjIwMTZaMAwwCgYDVR0VBAMKAQEwYAIBBBcNMjEwMTE2 +MDIyMDE2WjBMMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQGEwNVU0ExDTALBgNV +BAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0MqBjMGEwHwYDVR0j +BBgwFoAURJSDWAOfhGCryBjl8dsQjBitl3swCgYDVR0UBAMCAQEwMgYDVR0cAQH/ +BCgwJqAhoB+GHWh0dHA6Ly9jcmxzLnBraS5nb29nL3Rlc3QuY3JshAH/MA0GCSqG +SIb3DQEBCwUAA4IBAQBVXX67mr2wFPmEWCe6mf/wFnPl3xL6zNOl96YJtsd7ulcS +TEbdJpaUnWFQ23+Tpzdj/lI2aQhTg5Lvii3o+D8C5r/Jc5NhSOtVJJDI/IQLh4pG +NgGdljdbJQIT5D2Z71dgbq1ocxn8DefZIJjO3jp8VnAm7AIMX2tLTySzD2MpMeMq +XmcN4lG1e4nx+xjzp7MySYO42NRY3LkphVzJhu3dRBYhBKViRJxw9hLttChitJpF +6Kh6a0QzrEY/QDJGhE1VrAD2c5g/SKnHPDVoCWo4ACIICi76KQQSIWfIdp4W/SY3 +qsSIp8gfxSyzkJP+Ngkm2DdLjlJQCZ9R0MZP9Xj4 +-----END X509 CRL-----`) + + var tests = []struct { + desc string + in []byte + }{ + { + desc: "some reasons", + in: crlBytesSomeReasons, + }, + { + desc: "indirect", + in: crlBytesIndirect, + }, + } + + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + crl, err := x509.ParseCRL(tt.in) + if err != nil { + t.Fatal(err) + } + if _, err := parseCRLExtensions(crl); err == nil { + t.Error("expected error got ok") + } + }) + } +} + +func TestCheckCertRevocation(t *testing.T) { + dummyCrlFile := []byte(`-----BEGIN X509 CRL----- +MIIDGjCCAgICAQEwDQYJKoZIhvcNAQELBQAwdjELMAkGA1UEBhMCVVMxEzARBgNV +BAgTCkNhbGlmb3JuaWExFDASBgNVBAoTC1Rlc3RpbmcgTHRkMSowKAYDVQQLEyFU +ZXN0aW5nIEx0ZCBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkxEDAOBgNVBAMTB1Rlc3Qg +Q0EXDTIxMDExNjAyMjAxNloXDTIxMDEyMDA2MjAxNlowgfIwbAIBAhcNMjEwMTE2 +MDIyMDE2WjBYMAoGA1UdFQQDCgEEMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQG +EwNVU0ExDTALBgNVBAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0 +MTAgAgEDFw0yMTAxMTYwMjIwMTZaMAwwCgYDVR0VBAMKAQEwYAIBBBcNMjEwMTE2 +MDIyMDE2WjBMMEoGA1UdHQEB/wRAMD6kPDA6MQwwCgYDVQQGEwNVU0ExDTALBgNV 
+BAcTBGhlcmUxCzAJBgNVBAoTAnVzMQ4wDAYDVQQDEwVUZXN0MqBjMGEwHwYDVR0j +BBgwFoAURJSDWAOfhGCryBjl8dsQjBitl3swCgYDVR0UBAMCAQEwMgYDVR0cAQH/ +BCgwJqAhoB+GHWh0dHA6Ly9jcmxzLnBraS5nb29nL3Rlc3QuY3JshAH/MA0GCSqG +SIb3DQEBCwUAA4IBAQBVXX67mr2wFPmEWCe6mf/wFnPl3xL6zNOl96YJtsd7ulcS +TEbdJpaUnWFQ23+Tpzdj/lI2aQhTg5Lvii3o+D8C5r/Jc5NhSOtVJJDI/IQLh4pG +NgGdljdbJQIT5D2Z71dgbq1ocxn8DefZIJjO3jp8VnAm7AIMX2tLTySzD2MpMeMq +XmcN4lG1e4nx+xjzp7MySYO42NRY3LkphVzJhu3dRBYhBKViRJxw9hLttChitJpF +6Kh6a0QzrEY/QDJGhE1VrAD2c5g/SKnHPDVoCWo4ACIICi76KQQSIWfIdp4W/SY3 +qsSIp8gfxSyzkJP+Ngkm2DdLjlJQCZ9R0MZP9Xj4 +-----END X509 CRL-----`) + crl, err := x509.ParseCRL(dummyCrlFile) + if err != nil { + t.Fatalf("x509.ParseCRL(dummyCrlFile) failed: %v", err) + } + crlExt := &certificateListExt{CertList: crl} + var crlIssuer pkix.Name + crlIssuer.FillFromRDNSequence(&crl.TBSCertList.Issuer) + + var revocationTests = []struct { + desc string + in x509.Certificate + revoked RevocationStatus + }{ + { + desc: "Single revoked", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test1", + }, + SerialNumber: big.NewInt(2), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Revoked no entry issuer", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test1", + }, + SerialNumber: big.NewInt(3), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Revoked new entry issuer", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test2", + }, + SerialNumber: big.NewInt(4), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationRevoked, + }, + { + desc: "Single unrevoked", + in: x509.Certificate{ + Issuer: pkix.Name{ + Country: []string{"USA"}, + Locality: []string{"here"}, + Organization: []string{"us"}, + CommonName: "Test2", + }, + SerialNumber: big.NewInt(1), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationUnrevoked, + }, + { + desc: "Single unrevoked Issuer", + in: x509.Certificate{ + Issuer: crlIssuer, + SerialNumber: big.NewInt(2), + CRLDistributionPoints: []string{"test"}, + }, + revoked: RevocationUnrevoked, + }, + } + + for _, tt := range revocationTests { + rawIssuer, err := asn1.Marshal(tt.in.Issuer.ToRDNSequence()) + if err != nil { + t.Fatalf("asn1.Marshal(%v) failed: %v", tt.in.Issuer.ToRDNSequence(), err) + } + tt.in.RawIssuer = rawIssuer + t.Run(tt.desc, func(t *testing.T) { + rev, err := checkCertRevocation(&tt.in, crlExt) + if err != nil { + t.Errorf("checkCertRevocation(%v) err = %v", tt.in.Issuer, err) + } else if rev != tt.revoked { + t.Errorf("checkCertRevocation(%v(%v)) returned %v wanted %v", + tt.in.Issuer, tt.in.SerialNumber, rev, tt.revoked) + } + }) + } +} + +func makeChain(t *testing.T, name string) []*x509.Certificate { + t.Helper() + + certChain := make([]*x509.Certificate, 0) + + rest, err := ioutil.ReadFile(name) + if err != nil { + t.Fatalf("ioutil.ReadFile(%v) failed %v", name, err) + } + for len(rest) > 0 { + var block *pem.Block + block, rest = pem.Decode(rest) + c, err := x509.ParseCertificate(block.Bytes) + if err != nil { + t.Fatalf("ParseCertificate error %v", err) + } + t.Logf("Parsed Cert sub = %v iss = %v", c.Subject, c.Issuer) + certChain = append(certChain, c) + } + return certChain +} + +func loadCRL(t 
*testing.T, path string) *pkix.CertificateList {
+	b, err := ioutil.ReadFile(path)
+	if err != nil {
+		t.Fatalf("ioutil.ReadFile(%v) failed err = %v", path, err)
+	}
+	crl, err := x509.ParseCRL(b)
+	if err != nil {
+		t.Fatalf("ParseCRL(%v) failed err = %v", path, err)
+	}
+	return crl
+}
+
+func TestCachedCRL(t *testing.T) {
+	cache, err := lru.New(5)
+	if err != nil {
+		t.Fatalf("lru.New: err = %v", err)
+	}
+
+	tests := []struct {
+		desc string
+		val  interface{}
+		ok   bool
+	}{
+		{
+			desc: "Valid",
+			val: &certificateListExt{
+				CertList: &pkix.CertificateList{
+					TBSCertList: pkix.TBSCertificateList{
+						NextUpdate: time.Now().Add(time.Hour),
+					},
+				}},
+			ok: true,
+		},
+		{
+			desc: "Expired",
+			val: &certificateListExt{
+				CertList: &pkix.CertificateList{
+					TBSCertList: pkix.TBSCertificateList{
+						NextUpdate: time.Now().Add(-time.Hour),
+					},
+				}},
+			ok: false,
+		},
+		{
+			desc: "Wrong Type",
+			val:  "string",
+			ok:   false,
+		},
+		{
+			desc: "Empty",
+			val:  nil,
+			ok:   false,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			if tt.val != nil {
+				cache.Add(hex.EncodeToString([]byte(tt.desc)), tt.val)
+			}
+			_, ok := cachedCrl([]byte(tt.desc), cache)
+			if tt.ok != ok {
+				t.Errorf("Cache ok error expected %v vs %v", tt.ok, ok)
+			}
+		})
+	}
+}
+
+func TestGetIssuerCRLCache(t *testing.T) {
+	cache, err := lru.New(5)
+	if err != nil {
+		t.Fatalf("lru.New: err = %v", err)
+	}
+
+	tests := []struct {
+		desc      string
+		rawIssuer []byte
+		certs     []*x509.Certificate
+	}{
+		{
+			desc:      "Valid",
+			rawIssuer: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1].RawIssuer,
+			certs:     makeChain(t, testdata.Path("crl/unrevoked.pem")),
+		},
+		{
+			desc:      "Unverified",
+			rawIssuer: makeChain(t, testdata.Path("crl/unrevoked.pem"))[1].RawIssuer,
+		},
+		{
+			desc:      "Not Found",
+			rawIssuer: []byte("not_found"),
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.desc, func(t *testing.T) {
+			cache.Purge()
+			_, err := fetchIssuerCRL("test", tt.rawIssuer, tt.certs, RevocationConfig{
+				RootDir: testdata.Path("."),
+				Cache:   cache,
+			})
+			if err == nil && cache.Len() == 0 {
+				t.Error("Verified CRL not added to cache")
+			}
+			if err != nil && cache.Len() != 0 {
+				t.Error("Unverified CRL added to cache")
+			}
+		})
+	}
+}
+
+func TestVerifyCrl(t *testing.T) {
+	tampered := loadCRL(t, testdata.Path("crl/1.crl"))
+	// Change the signature so it won't verify.
+	tampered.SignatureValue.Bytes[0]++
+
+	verifyTests := []struct {
+		desc    string
+		crl     *pkix.CertificateList
+		certs   []*x509.Certificate
+		cert    *x509.Certificate
+		errWant string
+	}{
+		{
+			desc:    "Pass intermediate",
+			crl:     loadCRL(t, testdata.Path("crl/1.crl")),
+			certs:   makeChain(t, testdata.Path("crl/unrevoked.pem")),
+			cert:    makeChain(t, testdata.Path("crl/unrevoked.pem"))[1],
+			errWant: "",
+		},
+		{
+			desc:    "Pass leaf",
+			crl:     loadCRL(t, testdata.Path("crl/2.crl")),
+			certs:   makeChain(t, testdata.Path("crl/unrevoked.pem")),
+			cert:    makeChain(t, testdata.Path("crl/unrevoked.pem"))[2],
+			errWant: "",
+		},
+		{
+			desc:    "Fail wrong cert chain",
+			crl:     loadCRL(t, testdata.Path("crl/3.crl")),
+			certs:   makeChain(t, testdata.Path("crl/unrevoked.pem")),
+			cert:    makeChain(t, testdata.Path("crl/revokedInt.pem"))[1],
+			errWant: "No certificates matched",
+		},
+		{
+			desc:    "Fail no certs",
+			crl:     loadCRL(t, testdata.Path("crl/1.crl")),
+			certs:   []*x509.Certificate{},
+			cert:    makeChain(t, testdata.Path("crl/unrevoked.pem"))[1],
+			errWant: "No certificates matched",
+		},
+		{
+			desc:  "Fail Tampered signature",
+			crl:   tampered,
+			certs: makeChain(t, testdata.Path("crl/unrevoked.pem")),
+			
cert:    makeChain(t, testdata.Path("crl/unrevoked.pem"))[1],
+			errWant: "verification failure",
+		},
+	}
+
+	for _, tt := range verifyTests {
+		t.Run(tt.desc, func(t *testing.T) {
+			crlExt, err := parseCRLExtensions(tt.crl)
+			if err != nil {
+				t.Fatalf("parseCRLExtensions(%v) failed, err = %v", tt.crl.TBSCertList.Issuer, err)
+			}
+			err = verifyCRL(crlExt, tt.cert.RawIssuer, tt.certs)
+			switch {
+			case tt.errWant == "" && err != nil:
+				t.Errorf("Valid CRL did not verify err = %v", err)
+			case tt.errWant != "" && err == nil:
+				t.Error("Invalid CRL verified")
+			case tt.errWant != "" && !strings.Contains(err.Error(), tt.errWant):
+				t.Errorf("verifyCRL(_, %v, %v) = %v; want Contains(%v)", tt.cert.RawIssuer, tt.certs, err, tt.errWant)
+			}
+		})
+	}
+}
+
+func TestRevokedCert(t *testing.T) {
+	revokedIntChain := makeChain(t, testdata.Path("crl/revokedInt.pem"))
+	revokedLeafChain := makeChain(t, testdata.Path("crl/revokedLeaf.pem"))
+	validChain := makeChain(t, testdata.Path("crl/unrevoked.pem"))
+	cache, err := lru.New(5)
+	if err != nil {
+		t.Fatalf("lru.New: err = %v", err)
+	}
+
+	var revocationTests = []struct {
+		desc              string
+		in                tls.ConnectionState
+		revoked           bool
+		allowUndetermined bool
+	}{
+		{
+			desc:    "Single unrevoked",
+			in:      tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain}},
+			revoked: false,
+		},
+		{
+			desc:    "Single revoked intermediate",
+			in:      tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedIntChain}},
+			revoked: true,
+		},
+		{
+			desc:    "Single revoked leaf",
+			in:      tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedLeafChain}},
+			revoked: true,
+		},
+		{
+			desc:    "Multi one revoked",
+			in:      tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain, revokedLeafChain}},
+			revoked: false,
+		},
+		{
+			desc:    "Multi revoked",
+			in:      tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{revokedLeafChain, revokedIntChain}},
+			revoked: true,
+		},
+		{
+			desc:    "Multi unrevoked",
+			in:      tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{validChain, validChain}},
+			revoked: false,
+		},
+		{
+			desc: "Undetermined revoked",
+			in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{
+				{&x509.Certificate{CRLDistributionPoints: []string{"test"}}},
+			}},
+			revoked: true,
+		},
+		{
+			desc: "Undetermined allowed",
+			in: tls.ConnectionState{VerifiedChains: [][]*x509.Certificate{
+				{&x509.Certificate{CRLDistributionPoints: []string{"test"}}},
+			}},
+			revoked:           false,
+			allowUndetermined: true,
+		},
+	}
+
+	for _, tt := range revocationTests {
+		t.Run(tt.desc, func(t *testing.T) {
+			err := CheckRevocation(tt.in, RevocationConfig{
+				RootDir:           testdata.Path("crl"),
+				AllowUndetermined: tt.allowUndetermined,
+				Cache:             cache,
+			})
+			t.Logf("CheckRevocation err = %v", err)
+			if tt.revoked && err == nil {
+				t.Error("Revoked certificate chain was allowed")
+			} else if !tt.revoked && err != nil {
+				t.Error("Unrevoked certificate not allowed")
+			}
+		})
+	}
+}
+
+func setupTLSConn(t *testing.T) (net.Listener, *x509.Certificate, *ecdsa.PrivateKey) {
+	t.Helper()
+	templ := x509.Certificate{
+		SerialNumber:          big.NewInt(5),
+		BasicConstraintsValid: true,
+		NotBefore:             time.Now().Add(-time.Hour),
+		NotAfter:              time.Now().Add(time.Hour),
+		IsCA:                  true,
+		Subject:               pkix.Name{CommonName: "test-cert"},
+		KeyUsage:              x509.KeyUsageCertSign,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
+		IPAddresses:           []net.IP{net.ParseIP("::1")},
+		CRLDistributionPoints: []string{"http://static.corp.google.com/crl/campus-sln/borg"},
+	
} + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("ecdsa.GenerateKey failed err = %v", err) + } + rawCert, err := x509.CreateCertificate(rand.Reader, &templ, &templ, key.Public(), key) + if err != nil { + t.Fatalf("x509.CreateCertificate failed err = %v", err) + } + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + t.Fatalf("x509.ParseCertificate failed err = %v", err) + } + + srvCfg := tls.Config{ + Certificates: []tls.Certificate{ + { + Certificate: [][]byte{cert.Raw}, + PrivateKey: key, + }, + }, + } + l, err := tls.Listen("tcp6", "[::1]:0", &srvCfg) + if err != nil { + t.Fatalf("tls.Listen failed err = %v", err) + } + return l, cert, key +} + +// TestVerifyConnection will setup a client/server connection and check revocation in the real TLS dialer +func TestVerifyConnection(t *testing.T) { + lis, cert, key := setupTLSConn(t) + defer func() { + lis.Close() + }() + + var handshakeTests = []struct { + desc string + revoked []pkix.RevokedCertificate + success bool + }{ + { + desc: "Empty CRL", + revoked: []pkix.RevokedCertificate{}, + success: true, + }, + { + desc: "Revoked Cert", + revoked: []pkix.RevokedCertificate{ + { + SerialNumber: cert.SerialNumber, + RevocationTime: time.Now(), + }, + }, + success: false, + }, + } + for _, tt := range handshakeTests { + t.Run(tt.desc, func(t *testing.T) { + // Accept one connection. + go func() { + conn, err := lis.Accept() + if err != nil { + t.Errorf("tls.Accept failed err = %v", err) + } else { + conn.Write([]byte("Hello, World!")) + conn.Close() + } + }() + + dir, err := ioutil.TempDir("", "crl_dir") + if err != nil { + t.Fatalf("ioutil.TempDir failed err = %v", err) + } + defer os.RemoveAll(dir) + + crl, err := cert.CreateCRL(rand.Reader, key, tt.revoked, time.Now(), time.Now().Add(time.Hour)) + if err != nil { + t.Fatalf("templ.CreateCRL failed err = %v", err) + } + + err = ioutil.WriteFile(path.Join(dir, fmt.Sprintf("%s.r0", x509NameHash(cert.Subject.ToRDNSequence()))), crl, 0777) + if err != nil { + t.Fatalf("ioutil.WriteFile failed err = %v", err) + } + + cp := x509.NewCertPool() + cp.AddCert(cert) + cliCfg := tls.Config{ + RootCAs: cp, + VerifyConnection: func(cs tls.ConnectionState) error { + return CheckRevocation(cs, RevocationConfig{RootDir: dir}) + }, + } + conn, err := tls.Dial(lis.Addr().Network(), lis.Addr().String(), &cliCfg) + t.Logf("tls.Dial err = %v", err) + if tt.success && err != nil { + t.Errorf("Expected success got err = %v", err) + } + if !tt.success && err == nil { + t.Error("Expected error, but got success") + } + if err == nil { + conn.Close() + } + }) + } +} diff --git a/security/advancedtls/examples/go.mod b/security/advancedtls/examples/go.mod index f141d4c05c81..20ed81e24d38 100644 --- a/security/advancedtls/examples/go.mod +++ b/security/advancedtls/examples/go.mod @@ -3,7 +3,7 @@ module google.golang.org/grpc/security/advancedtls/examples go 1.15 require ( - google.golang.org/grpc v1.36.0 + google.golang.org/grpc v1.38.0 google.golang.org/grpc/examples v0.0.0-20201112215255-90f1b3ee835b google.golang.org/grpc/security/advancedtls v0.0.0-20201112215255-90f1b3ee835b ) diff --git a/security/advancedtls/examples/go.sum b/security/advancedtls/examples/go.sum index c24d070fa7f1..004437a7ea66 100644 --- a/security/advancedtls/examples/go.sum +++ b/security/advancedtls/examples/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/OneOfOne/xxhash v1.2.2/go.mod 
h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -27,9 +29,12 @@ github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= diff --git a/security/advancedtls/go.mod b/security/advancedtls/go.mod index 90dee4a46fd0..75527018ee78 100644 --- a/security/advancedtls/go.mod +++ b/security/advancedtls/go.mod @@ -4,7 +4,8 @@ go 1.14 require ( github.com/google/go-cmp v0.5.1 // indirect - google.golang.org/grpc v1.36.0 + github.com/hashicorp/golang-lru v0.5.4 + google.golang.org/grpc v1.38.0 google.golang.org/grpc/examples v0.0.0-20201112215255-90f1b3ee835b ) diff --git a/security/advancedtls/go.sum b/security/advancedtls/go.sum index c24d070fa7f1..004437a7ea66 100644 --- a/security/advancedtls/go.sum +++ b/security/advancedtls/go.sum @@ -1,6 +1,8 @@ cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/cncf/xds/go v0.0.0-20210312221358-fbca930ec8ed/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -27,9 +29,12 @@ github.com/google/go-cmp v0.5.1 h1:JFrFEBb2xKufg6XkJsJr+WbKb4FQlURi5RUcBveYu9k= github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/uuid v1.1.2/go.mod 
h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= diff --git a/security/advancedtls/testdata/crl/0b35a562.r0 b/security/advancedtls/testdata/crl/0b35a562.r0 new file mode 120000 index 000000000000..1a84eabdfc72 --- /dev/null +++ b/security/advancedtls/testdata/crl/0b35a562.r0 @@ -0,0 +1 @@ +5.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/0b35a562.r1 b/security/advancedtls/testdata/crl/0b35a562.r1 new file mode 120000 index 000000000000..6e6f10978918 --- /dev/null +++ b/security/advancedtls/testdata/crl/0b35a562.r1 @@ -0,0 +1 @@ +1.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/1.crl b/security/advancedtls/testdata/crl/1.crl new file mode 100644 index 000000000000..5b12ded4a66f --- /dev/null +++ b/security/advancedtls/testdata/crl/1.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMDozNi0wODowMCkX +DTIxMDIwMjE1MzAzNloXDTIxMDIwOTE1MzAzNlqgLzAtMB8GA1UdIwQYMBaAFPQN +tnCIBcG4ReQgoVi0kPgTROseMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +IQDB9WEPBPHEo5xjCv8CT9okockJJnkLDOus6FypVLqj5QIgYw9/PYLwb41/Uc+4 +LLTAsfdDWh7xBJmqvVQglMoJOEc= +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/1ab871c8.r0 b/security/advancedtls/testdata/crl/1ab871c8.r0 new file mode 120000 index 000000000000..f2cd877e7edb --- /dev/null +++ b/security/advancedtls/testdata/crl/1ab871c8.r0 @@ -0,0 +1 @@ +2.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/2.crl b/security/advancedtls/testdata/crl/2.crl new file mode 100644 index 000000000000..5ca9afd71419 --- /dev/null +++ b/security/advancedtls/testdata/crl/2.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMDozNi0wODowMCkX +DTIxMDIwMjE1MzAzNloXDTIxMDIwOTE1MzAzNlqgLzAtMB8GA1UdIwQYMBaAFBjo +V5Jnk/gp1k7fmWwkvTk/cF/IMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +IQDgjA1Vj/pNFtNRL0vFEdapmFoArHM2+rn4IiP8jYLsCAIgAj2KEHbbtJ3zl5XP +WVW6ZyW7r3wIX+Bt3vLJWPrQtf8= +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/3.crl b/security/advancedtls/testdata/crl/3.crl new file mode 100644 index 000000000000..d37ad2247f59 --- 
/dev/null +++ b/security/advancedtls/testdata/crl/3.crl @@ -0,0 +1,11 @@ +-----BEGIN X509 CRL----- +MIIBiDCCAS8CAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkX +DTIxMDIwMjE1MzE1NFoXDTIxMDIwOTE1MzE1NFowJzAlAhQAroEYW855BRqTrlov +5cBCGvkutxcNMjEwMjAyMTUzMTU0WqAvMC0wHwYDVR0jBBgwFoAUeq/TQ959KbWk +/um08jSTXogXpWUwCgYDVR0UBAMCAQEwCgYIKoZIzj0EAwIDRwAwRAIgaSOIhJDg +wOLYlbXkmxW0cqy/AfOUNYbz5D/8/FfvhosCICftg7Vzlu0Nh83jikyjy+wtkiJt +ZYNvGFQ3Sp2L3A9e +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/4.crl b/security/advancedtls/testdata/crl/4.crl new file mode 100644 index 000000000000..d4ee6f7cf186 --- /dev/null +++ b/security/advancedtls/testdata/crl/4.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBYDCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkX +DTIxMDIwMjE1MzE1NFoXDTIxMDIwOTE1MzE1NFqgLzAtMB8GA1UdIwQYMBaAFIVn +8tIFgZpIdhomgYJ2c5ULLzpSMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0gAMEUC +ICupTvOqgAyRa1nn7+Pe/1vvlJPAQ8gUfTQsQ6XX3v6oAiEA08B2PsK6aTEwzjry +pXqhlUNZFzgaXrVVQuEJbyJ1qoU= +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/5.crl b/security/advancedtls/testdata/crl/5.crl new file mode 100644 index 000000000000..d1c24f0f25a3 --- /dev/null +++ b/security/advancedtls/testdata/crl/5.crl @@ -0,0 +1,10 @@ +-----BEGIN X509 CRL----- +MIIBXzCCAQYCAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkX +DTIxMDIwMjE1MzI1N1oXDTIxMDIwOTE1MzI1N1qgLzAtMB8GA1UdIwQYMBaAFN+g +xTAtSTlb5Qqvrbp4rZtsaNzqMAoGA1UdFAQDAgEAMAoGCCqGSM49BAMCA0cAMEQC +IHrRKjieY7w7gxvpkJAdszPZBlaSSp/c9wILutBTy7SyAiAwhaHfgas89iRfaBs2 +EhGIeK39A+kSzqu6qEQBHpK36g== +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/6.crl b/security/advancedtls/testdata/crl/6.crl new file mode 100644 index 000000000000..87ef378f6aba --- /dev/null +++ b/security/advancedtls/testdata/crl/6.crl @@ -0,0 +1,11 @@ +-----BEGIN X509 CRL----- +MIIBiDCCAS8CAQEwCgYIKoZIzj0EAwIwgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQI +EwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpH +b29nbGUgTExDMSYwEQYDVQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNs +bjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkX +DTIxMDIwMjE1MzI1N1oXDTIxMDIwOTE1MzI1N1owJzAlAhQAxSe/pGmyvzN7mxm5 +6ZJTYUXYuhcNMjEwMjAyMTUzMjU3WqAvMC0wHwYDVR0jBBgwFoAUpZ30UJXB4lI9 +j2SzodCtRFckrRcwCgYDVR0UBAMCAQEwCgYIKoZIzj0EAwIDRwAwRAIgRg3u7t3b +oyV5FhMuGGzWnfIwnKclpT8imnp8tEN253sCIFUY7DjiDohwu4Zup3bWs1OaZ3q3 +cm+j0H/oe8zzCAgp +-----END X509 CRL----- diff --git a/security/advancedtls/testdata/crl/71eac5a2.r0 b/security/advancedtls/testdata/crl/71eac5a2.r0 new file mode 120000 index 000000000000..9f37924cae0c --- /dev/null +++ b/security/advancedtls/testdata/crl/71eac5a2.r0 @@ -0,0 +1 @@ +4.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/7a1799af.r0 b/security/advancedtls/testdata/crl/7a1799af.r0 new file mode 120000 index 000000000000..f34df5b59c01 --- /dev/null +++ b/security/advancedtls/testdata/crl/7a1799af.r0 @@ -0,0 +1 @@ +3.crl \ 
No newline at end of file diff --git a/security/advancedtls/testdata/crl/8828a7e6.r0 b/security/advancedtls/testdata/crl/8828a7e6.r0 new file mode 120000 index 000000000000..70bead214cc3 --- /dev/null +++ b/security/advancedtls/testdata/crl/8828a7e6.r0 @@ -0,0 +1 @@ +6.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/README.md b/security/advancedtls/testdata/crl/README.md new file mode 100644 index 000000000000..00cb09c31928 --- /dev/null +++ b/security/advancedtls/testdata/crl/README.md @@ -0,0 +1,48 @@ +# CRL Test Data +
+This directory contains cert chains and CRL files for revocation testing.
+
+To print a chain, use a command like:
+
+```shell
+openssl crl2pkcs7 -nocrl -certfile security/crl/x509/client/testdata/revokedLeaf.pem | openssl pkcs7 -print_certs -text -noout
+```
+
+The CRL file symlinks are generated with `openssl rehash`.
+
+## unrevoked.pem
+
+A certificate chain with CRL files and unrevoked certs.
+
+* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production,
+ OU=campus-sln, CN=Root CA (2021-02-02T07:30:36-08:00)
+ * 1.crl
+
+NOTE: 1.crl and 5.crl are symlinked under the same hash name to simulate two
+issuers that hash to the same value, to test that loading multiple CRL files works.
+
+* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production,
+ OU=campus-sln, CN=node CA (2021-02-02T07:30:36-08:00)
+ * 2.crl
+
+## revokedInt.pem
+
+Certificate chain where the intermediate is revoked.
+
+* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production,
+ OU=campus-sln, CN=Root CA (2021-02-02T07:31:54-08:00)
+ * 3.crl
+* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production,
+ OU=campus-sln, CN=node CA (2021-02-02T07:31:54-08:00)
+ * 4.crl
+
+## revokedLeaf.pem
+
+Certificate chain where the leaf is revoked.
+
+* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production,
+ OU=campus-sln, CN=Root CA (2021-02-02T07:32:57-08:00)
+ * 5.crl
+* Subject: C=US, ST=California, L=Mountain View, O=Google LLC, OU=Production,
+ OU=campus-sln, CN=node CA (2021-02-02T07:32:57-08:00)
+ * 6.crl diff --git a/security/advancedtls/testdata/crl/deee447d.r0 b/security/advancedtls/testdata/crl/deee447d.r0 new file mode 120000 index 000000000000..1a84eabdfc72 --- /dev/null +++ b/security/advancedtls/testdata/crl/deee447d.r0 @@ -0,0 +1 @@ +5.crl \ No newline at end of file diff --git a/security/advancedtls/testdata/crl/revokedInt.pem b/security/advancedtls/testdata/crl/revokedInt.pem new file mode 100644 index 000000000000..8b7282ff8221 --- /dev/null +++ b/security/advancedtls/testdata/crl/revokedInt.pem @@ -0,0 +1,58 @@ +-----BEGIN CERTIFICATE----- +MIIDAzCCAqmgAwIBAgITAWjKwm2dNQvkO62Jgyr5rAvVQzAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNSb290IENBICgyMDIx +LTAyLTAyVDA3OjMxOjU0LTA4OjAwKTAgFw0yMTAyMDIxNTMxNTRaGA85OTk5MTIz +MTIzNTk1OVowgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYw +FAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYD +VQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9v +dCBDQSAoMjAyMS0wMi0wMlQwNzozMTo1NC0wODowMCkwWTATBgcqhkjOPQIBBggq +hkjOPQMBBwNCAAQhA0/puhTtSxbVVHseVhL2z7QhpPyJs5Q4beKi7tpaYRDmVn6p +Phh+jbRzg8Qj4gKI/Q1rrdm4rKer63LHpdWdo4GzMIGwMA4GA1UdDwEB/wQEAwIB +BjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUwAwEB +/zAdBgNVHQ4EFgQUeq/TQ959KbWk/um08jSTXogXpWUwHwYDVR0jBBgwFoAUeq/T
+Q959KbWk/um08jSTXogXpWUwLgYDVR0RBCcwJYYjc3BpZmZlOi8vY2FtcHVzLXNs +bi5wcm9kLmdvb2dsZS5jb20wCgYIKoZIzj0EAwIDSAAwRQIgOSQZvyDPQwVOWnpF +zWvI+DS2yXIj/2T2EOvJz2XgcK4CIQCL0mh/+DxLiO4zzbInKr0mxpGSxSeZCUk7 +1ZF7AeLlbw== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDizCCAzKgAwIBAgIUAK6BGFvOeQUak65aL+XAQhr5LrcwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMTo1NC0wODowMCkwIBcNMjEwMjAyMTUzMTU0WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzE6NTQtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEye6UOlBos8Q3FFBiLahD9BaLTA18bO4MTPyv35T3lppvxD5X +U/AnEllOnx5OMtMjMBbIQjSkMbiQ9xNXoSqB6aOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUhWfy0gWBmkh2GiaBgnZzlQsvOlIwHwYDVR0jBBgwFoAU +eq/TQ959KbWk/um08jSTXogXpWUwMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj +cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNHADBEAiA79rPu6ZO1/0qB6RxL7jVz1200 +UTo8ioB4itbTzMnJqAIgJqp/Rc8OhpsfzQX8XnIIkl+SewT+tOxJT1MHVNMlVhc= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIC0DCCAnWgAwIBAgITXQ2c/C27OGqk4Pbu+MNJlOtpYTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNub2RlIENBICgyMDIx +LTAyLTAyVDA3OjMxOjU0LTA4OjAwKTAgFw0yMTAyMDIxNTMxNTRaGA85OTk5MTIz +MTIzNTk1OVowADBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABN2/1le5d3hS/piw +hrNMHjd7gPEjzXwtuXQTzdV+aaeOf3ldnC6OnEF/bggym9MldQSJZLXPYSaoj430 +Vu5PRNejggEkMIIBIDAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUH +AwIGCCsGAQUFBwMBMB0GA1UdDgQWBBTEewP3JgrJPekWWGGjChVqaMhaqTAfBgNV +HSMEGDAWgBSFZ/LSBYGaSHYaJoGCdnOVCy86UjBrBgNVHREBAf8EYTBfghZqemFi +MTIucHJvZC5nb29nbGUuY29thkVzcGlmZmU6Ly9jc2NzLXRlYW0ubm9kZS5jYW1w +dXMtc2xuLnByb2QuZ29vZ2xlLmNvbS9yb2xlL2JvcmctYWRtaW4tY28wQgYDVR0f +BDswOTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2Nh +bXB1cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNJADBGAiEA9w4qp3nHpXo+6d7mZc69 +QoALfP5ynfBCArt8bAlToo8CIQCgc/lTfl2BtBko+7h/w6pKxLeuoQkvCL5gHFyK +LXE6vA== +-----END CERTIFICATE----- diff --git a/security/advancedtls/testdata/crl/revokedLeaf.pem b/security/advancedtls/testdata/crl/revokedLeaf.pem new file mode 100644 index 000000000000..b7541abf6214 --- /dev/null +++ b/security/advancedtls/testdata/crl/revokedLeaf.pem @@ -0,0 +1,59 @@ +-----BEGIN CERTIFICATE----- +MIIDAzCCAqmgAwIBAgITTwodm6C4ZabFVUVa5yBw0TbzJTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNSb290IENBICgyMDIx +LTAyLTAyVDA3OjMyOjU3LTA4OjAwKTAgFw0yMTAyMDIxNTMyNTdaGA85OTk5MTIz +MTIzNTk1OVowgaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYw +FAYDVQQHEw1Nb3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYD +VQQLEwpQcm9kdWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9v +dCBDQSAoMjAyMS0wMi0wMlQwNzozMjo1Ny0wODowMCkwWTATBgcqhkjOPQIBBggq +hkjOPQMBBwNCAARoZnzQWvAoyhvCLA2cFIK17khSaA9aA+flS5X9fLRt4RsfPCx3 
+kim7wYKQSmBhQdc1UM4h3969r1c1Fvsh2H9qo4GzMIGwMA4GA1UdDwEB/wQEAwIB +BjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUwAwEB +/zAdBgNVHQ4EFgQU36DFMC1JOVvlCq+tunitm2xo3OowHwYDVR0jBBgwFoAU36DF +MC1JOVvlCq+tunitm2xo3OowLgYDVR0RBCcwJYYjc3BpZmZlOi8vY2FtcHVzLXNs +bi5wcm9kLmdvb2dsZS5jb20wCgYIKoZIzj0EAwIDSAAwRQIgN7S9dQOQzNih92ag +7c5uQxuz+M6wnxWj/uwGQIIghRUCIQD2UDH6kkRSYQuyP0oN7XYO3XFjmZ2Yer6m +1ZS8fyWYYA== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDjTCCAzKgAwIBAgIUAOmArBu9gihLTlqP3W7Et0UoocEwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMjo1Ny0wODowMCkwIBcNMjEwMjAyMTUzMjU3WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzI6NTctMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEfrgVEVQfSEFeCF1/FGeW7oq0yxecenT1BESfj4Z0zJ8p7P9W +bj1o6Rn6dUNlEhGrx7E3/4NFJ0cL1BSNGHkjiqOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUpZ30UJXB4lI9j2SzodCtRFckrRcwHwYDVR0jBBgwFoAU +36DFMC1JOVvlCq+tunitm2xo3OowMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj +cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNJADBGAiEAnuONgMqmbBlj4ibw5BgDtZUM +pboACSFJtEOJu4Yqjt0CIQDI5193J4wUcAY0BK0vO9rRfbNOIc+4ke9ieBDPSuhm +mA== +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIICzzCCAnagAwIBAgIUAMUnv6Rpsr8ze5sZuemSU2FF2LowCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjbm9kZSBDQSAoMjAy +MS0wMi0wMlQwNzozMjo1Ny0wODowMCkwIBcNMjEwMjAyMTUzMjU3WhgPOTk5OTEy +MzEyMzU5NTlaMAAwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAASCmYiIHUux5WFz +S0ksJzAPL7YTEh5o5MdXgLPB/WM6x9sVsQDSYU0PF5qc9vPNhkQzGBW79dkBnxhW +AGJkFr1Po4IBJDCCASAwDgYDVR0PAQH/BAQDAgeAMB0GA1UdJQQWMBQGCCsGAQUF +BwMCBggrBgEFBQcDATAdBgNVHQ4EFgQUCR1CGEdlks0qcxCExO0rP1B/Z7UwHwYD +VR0jBBgwFoAUpZ30UJXB4lI9j2SzodCtRFckrRcwawYDVR0RAQH/BGEwX4IWanph +YjEyLnByb2QuZ29vZ2xlLmNvbYZFc3BpZmZlOi8vY3Njcy10ZWFtLm5vZGUuY2Ft +cHVzLXNsbi5wcm9kLmdvb2dsZS5jb20vcm9sZS9ib3JnLWFkbWluLWNvMEIGA1Ud +HwQ7MDkwN6A1oDOGMWh0dHA6Ly9zdGF0aWMuY29ycC5nb29nbGUuY29tL2NybC9j +YW1wdXMtc2xuL25vZGUwCgYIKoZIzj0EAwIDRwAwRAIgK9vQYNoL8HlEwWv89ioG +aQ1+8swq6Bo/5mJBrdVLvY8CIGxo6M9vJkPdObmetWNC+lmKuZDoqJWI0AAmBT2J +mR2r +-----END CERTIFICATE----- diff --git a/security/advancedtls/testdata/crl/unrevoked.pem b/security/advancedtls/testdata/crl/unrevoked.pem new file mode 100644 index 000000000000..5c5fc58a7a5e --- /dev/null +++ b/security/advancedtls/testdata/crl/unrevoked.pem @@ -0,0 +1,58 @@ +-----BEGIN CERTIFICATE----- +MIIDBDCCAqqgAwIBAgIUALy864QhnkTdceLH52k2XVOe8IQwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMDozNi0wODowMCkwIBcNMjEwMjAyMTUzMDM2WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG 
+A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI1Jv +b3QgQ0EgKDIwMjEtMDItMDJUMDc6MzA6MzYtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEYv/JS5hQ5kIgdKqYZWTKCO/6gloHAmIb1G8lmY0oXLXYNHQ4 +qHN7/pPtlcHQp0WK/hM8IGvgOUDoynA8mj0H9KOBszCBsDAOBgNVHQ8BAf8EBAMC +AQYwHQYDVR0lBBYwFAYIKwYBBQUHAwEGCCsGAQUFBwMCMA8GA1UdEwEB/wQFMAMB +Af8wHQYDVR0OBBYEFPQNtnCIBcG4ReQgoVi0kPgTROseMB8GA1UdIwQYMBaAFPQN +tnCIBcG4ReQgoVi0kPgTROseMC4GA1UdEQQnMCWGI3NwaWZmZTovL2NhbXB1cy1z +bG4ucHJvZC5nb29nbGUuY29tMAoGCCqGSM49BAMCA0gAMEUCIQDwBn20DB4X/7Uk +Q5BR8JxQYUPxOfvuedjfeA8bPvQ2FwIgOEWa0cXJs1JxarILJeCXtdXvBgu6LEGQ +3Pk/bgz8Gek= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIIDizCCAzKgAwIBAgIUAM/6RKQ7Vke0i4xp5LaAqV73cmIwCgYIKoZIzj0EAwIw +gaUxCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1N +b3VudGFpbiBWaWV3MRMwEQYDVQQKEwpHb29nbGUgTExDMSYwEQYDVQQLEwpQcm9k +dWN0aW9uMBEGA1UECxMKY2FtcHVzLXNsbjEsMCoGA1UEAxMjUm9vdCBDQSAoMjAy +MS0wMi0wMlQwNzozMDozNi0wODowMCkwIBcNMjEwMjAyMTUzMDM2WhgPOTk5OTEy +MzEyMzU5NTlaMIGlMQswCQYDVQQGEwJVUzETMBEGA1UECBMKQ2FsaWZvcm5pYTEW +MBQGA1UEBxMNTW91bnRhaW4gVmlldzETMBEGA1UEChMKR29vZ2xlIExMQzEmMBEG +A1UECxMKUHJvZHVjdGlvbjARBgNVBAsTCmNhbXB1cy1zbG4xLDAqBgNVBAMTI25v +ZGUgQ0EgKDIwMjEtMDItMDJUMDc6MzA6MzYtMDg6MDApMFkwEwYHKoZIzj0CAQYI +KoZIzj0DAQcDQgAEllnhxmMYiUPUgRGmenbnm10gXpM94zHx3D1/HumPs6arjYuT +Zlhx81XL+g4bu4HII2qcGdP+Hqj/MMFNDI9z4aOCATowggE2MA4GA1UdDwEB/wQE +AwIBBjAdBgNVHSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDwYDVR0TAQH/BAUw +AwEB/zAdBgNVHQ4EFgQUGOhXkmeT+CnWTt+ZbCS9OT9wX8gwHwYDVR0jBBgwFoAU +9A22cIgFwbhF5CChWLSQ+BNE6x4wMwYDVR0RBCwwKoYoc3BpZmZlOi8vbm9kZS5j +YW1wdXMtc2xuLnByb2QuZ29vZ2xlLmNvbTA7BgNVHR4BAf8EMTAvoC0wK4YpY3Nj +cy10ZWFtLm5vZGUuY2FtcHVzLXNsbi5wcm9kLmdvb2dsZS5jb20wQgYDVR0fBDsw +OTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2NhbXB1 +cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNHADBEAiA86egqPw0qyapAeMGbHxrmYZYa +i5ARQsSKRmQixgYizQIgW+2iRWN6Kbqt4WcwpmGv/xDckdRXakF5Ign/WUDO5u4= +-----END CERTIFICATE----- +-----BEGIN CERTIFICATE----- +MIICzzCCAnWgAwIBAgITYjjKfYZUKQNUjNyF+hLDGpHJKTAKBggqhkjOPQQDAjCB +pTELMAkGA1UEBhMCVVMxEzARBgNVBAgTCkNhbGlmb3JuaWExFjAUBgNVBAcTDU1v +dW50YWluIFZpZXcxEzARBgNVBAoTCkdvb2dsZSBMTEMxJjARBgNVBAsTClByb2R1 +Y3Rpb24wEQYDVQQLEwpjYW1wdXMtc2xuMSwwKgYDVQQDEyNub2RlIENBICgyMDIx +LTAyLTAyVDA3OjMwOjM2LTA4OjAwKTAgFw0yMTAyMDIxNTMwMzZaGA85OTk5MTIz +MTIzNTk1OVowADBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABD4r4+nCgZExYF8v +CLvGn0lY/cmam8mAkJDXRN2Ja2t+JwaTOptPmbbXft+1NTk5gCg5wB+FJCnaV3I/ +HaxEhBWjggEkMIIBIDAOBgNVHQ8BAf8EBAMCB4AwHQYDVR0lBBYwFAYIKwYBBQUH +AwIGCCsGAQUFBwMBMB0GA1UdDgQWBBTTCjXX1Txjc00tBg/5cFzpeCSKuDAfBgNV +HSMEGDAWgBQY6FeSZ5P4KdZO35lsJL05P3BfyDBrBgNVHREBAf8EYTBfghZqemFi +MTIucHJvZC5nb29nbGUuY29thkVzcGlmZmU6Ly9jc2NzLXRlYW0ubm9kZS5jYW1w +dXMtc2xuLnByb2QuZ29vZ2xlLmNvbS9yb2xlL2JvcmctYWRtaW4tY28wQgYDVR0f +BDswOTA3oDWgM4YxaHR0cDovL3N0YXRpYy5jb3JwLmdvb2dsZS5jb20vY3JsL2Nh +bXB1cy1zbG4vbm9kZTAKBggqhkjOPQQDAgNIADBFAiBq3URViNyMLpvzZHC1Y+4L ++35guyIJfjHu08P3S8/xswIhAJtWSQ1ZtozdOzGxg7GfUo4hR+5SP6rBTgIqXEfq +48fW +-----END CERTIFICATE----- diff --git a/server.go b/server.go index 446995986c8b..de1708306417 100644 --- a/server.go +++ b/server.go @@ -844,10 +844,16 @@ func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) { // ErrConnDispatched means that the connection was dispatched away from // gRPC; those connections should be left open. 
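// (A credentials implementation returns ErrConnDispatched once it has taken
// ownership of the raw connection, for example when the same listener also
// serves non-gRPC traffic; closing the connection here would break that owner.)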
if err != credentials.ErrConnDispatched { - s.mu.Lock() - s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) - s.mu.Unlock() - channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) + // In deployments where a gRPC server runs behind a cloud load + // balancer which performs regular TCP level health checks, the + // connection is closed immediately by the latter. Skipping the + // error here will help reduce log clutter. + if err != io.EOF { + s.mu.Lock() + s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err) + s.mu.Unlock() + channelz.Warningf(logger, s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err) + } rawConn.Close() } rawConn.SetDeadline(time.Time{}) @@ -897,7 +903,7 @@ func (s *Server) newHTTP2Transport(c net.Conn, authInfo credentials.AuthInfo) tr MaxHeaderListSize: s.opts.maxHeaderListSize, HeaderTableSize: s.opts.headerTableSize, } - st, err := transport.NewServerTransport("http2", c, config) + st, err := transport.NewServerTransport(c, config) if err != nil { s.mu.Lock() s.errorf("NewServerTransport(%q) failed: %v", c.RemoteAddr(), err) @@ -1109,22 +1115,24 @@ func chainUnaryServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - chainedInt = func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { - return interceptors[0](ctx, req, info, getChainUnaryHandler(interceptors, 0, info, handler)) - } + chainedInt = chainUnaryInterceptors(interceptors) } s.opts.unaryInt = chainedInt } -// getChainUnaryHandler recursively generate the chained UnaryHandler -func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info *UnaryServerInfo, finalHandler UnaryHandler) UnaryHandler { - if curr == len(interceptors)-1 { - return finalHandler - } - - return func(ctx context.Context, req interface{}) (interface{}, error) { - return interceptors[curr+1](ctx, req, info, getChainUnaryHandler(interceptors, curr+1, info, finalHandler)) +func chainUnaryInterceptors(interceptors []UnaryServerInterceptor) UnaryServerInterceptor { + return func(ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler) (interface{}, error) { + var i int + var next UnaryHandler + next = func(ctx context.Context, req interface{}) (interface{}, error) { + if i == len(interceptors)-1 { + return interceptors[i](ctx, req, info, handler) + } + i++ + return interceptors[i-1](ctx, req, info, next) + } + return next(ctx, req) } } @@ -1138,7 +1146,9 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
if sh != nil { beginTime := time.Now() statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: beginTime, + IsClientStream: false, + IsServerStream: false, } sh.HandleRPC(stream.Context(), statsBegin) } @@ -1390,22 +1400,24 @@ func chainStreamServerInterceptors(s *Server) { } else if len(interceptors) == 1 { chainedInt = interceptors[0] } else { - chainedInt = func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { - return interceptors[0](srv, ss, info, getChainStreamHandler(interceptors, 0, info, handler)) - } + chainedInt = chainStreamInterceptors(interceptors) } s.opts.streamInt = chainedInt } -// getChainStreamHandler recursively generate the chained StreamHandler -func getChainStreamHandler(interceptors []StreamServerInterceptor, curr int, info *StreamServerInfo, finalHandler StreamHandler) StreamHandler { - if curr == len(interceptors)-1 { - return finalHandler - } - - return func(srv interface{}, ss ServerStream) error { - return interceptors[curr+1](srv, ss, info, getChainStreamHandler(interceptors, curr+1, info, finalHandler)) +func chainStreamInterceptors(interceptors []StreamServerInterceptor) StreamServerInterceptor { + return func(srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error { + var i int + var next StreamHandler + next = func(srv interface{}, ss ServerStream) error { + if i == len(interceptors)-1 { + return interceptors[i](srv, ss, info, handler) + } + i++ + return interceptors[i-1](srv, ss, info, next) + } + return next(srv, ss) } } @@ -1418,7 +1430,9 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp if sh != nil { beginTime := time.Now() statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: beginTime, + IsClientStream: sd.ClientStreams, + IsServerStream: sd.ServerStreams, } sh.HandleRPC(stream.Context(), statsBegin) } diff --git a/server_test.go b/server_test.go index fcfde30706c3..b15939160144 100644 --- a/server_test.go +++ b/server_test.go @@ -22,6 +22,7 @@ import ( "context" "net" "reflect" + "strconv" "strings" "testing" "time" @@ -130,3 +131,59 @@ func (s) TestStreamContext(t *testing.T) { t.Fatalf("GetStreamFromContext(%v) = %v, %t, want: %v, true", ctx, stream, ok, expectedStream) } } + +func BenchmarkChainUnaryInterceptor(b *testing.B) { + for _, n := range []int{1, 3, 5, 10} { + n := n + b.Run(strconv.Itoa(n), func(b *testing.B) { + interceptors := make([]UnaryServerInterceptor, 0, n) + for i := 0; i < n; i++ { + interceptors = append(interceptors, func( + ctx context.Context, req interface{}, info *UnaryServerInfo, handler UnaryHandler, + ) (interface{}, error) { + return handler(ctx, req) + }) + } + + s := NewServer(ChainUnaryInterceptor(interceptors...)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if _, err := s.opts.unaryInt(context.Background(), nil, nil, + func(ctx context.Context, req interface{}) (interface{}, error) { + return nil, nil + }, + ); err != nil { + b.Fatal(err) + } + } + }) + } +} + +func BenchmarkChainStreamInterceptor(b *testing.B) { + for _, n := range []int{1, 3, 5, 10} { + n := n + b.Run(strconv.Itoa(n), func(b *testing.B) { + interceptors := make([]StreamServerInterceptor, 0, n) + for i := 0; i < n; i++ { + interceptors = append(interceptors, func( + srv interface{}, ss ServerStream, info *StreamServerInfo, handler StreamHandler, + ) error { + return handler(srv, ss) + }) + } + + s := NewServer(ChainStreamInterceptor(interceptors...)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; 
i < b.N; i++ { + if err := s.opts.streamInt(nil, nil, nil, func(srv interface{}, stream ServerStream) error { + return nil + }); err != nil { + b.Fatal(err) + } + } + }) + } +} diff --git a/stats/stats.go b/stats/stats.go index 63e476ee7ff8..a5ebeeb6932d 100644 --- a/stats/stats.go +++ b/stats/stats.go @@ -45,6 +45,10 @@ type Begin struct { BeginTime time.Time // FailFast indicates if this RPC is failfast. FailFast bool + // IsClientStream indicates whether the RPC is a client streaming RPC. + IsClientStream bool + // IsServerStream indicates whether the RPC is a server streaming RPC. + IsServerStream bool } // IsClient indicates if the stats information is from client side. diff --git a/stats/stats_test.go b/stats/stats_test.go index 306f2f6b8e90..dfc6edfc3d36 100644 --- a/stats/stats_test.go +++ b/stats/stats_test.go @@ -407,15 +407,17 @@ func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.StreamingOutputCallReq } type expectedData struct { - method string - serverAddr string - compression string - reqIdx int - requests []proto.Message - respIdx int - responses []proto.Message - err error - failfast bool + method string + isClientStream bool + isServerStream bool + serverAddr string + compression string + reqIdx int + requests []proto.Message + respIdx int + responses []proto.Message + err error + failfast bool } type gotData struct { @@ -456,6 +458,12 @@ func checkBegin(t *testing.T, d *gotData, e *expectedData) { t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast) } } + if st.IsClientStream != e.isClientStream { + t.Fatalf("st.IsClientStream = %v, want %v", st.IsClientStream, e.isClientStream) + } + if st.IsServerStream != e.isServerStream { + t.Fatalf("st.IsServerStream = %v, want %v", st.IsServerStream, e.isServerStream) + } } func checkInHeader(t *testing.T, d *gotData, e *expectedData) { @@ -847,6 +855,9 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f err error method string + isClientStream bool + isServerStream bool + req proto.Message resp proto.Message e error @@ -864,14 +875,18 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f reqs, resp, e = te.doClientStreamCall(cc) resps = []proto.Message{resp} err = e + isClientStream = true case serverStreamRPC: method = "/grpc.testing.TestService/StreamingOutputCall" req, resps, e = te.doServerStreamCall(cc) reqs = []proto.Message{req} err = e + isServerStream = true case fullDuplexStreamRPC: method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) + isClientStream = true + isServerStream = true } if cc.success != (err == nil) { t.Fatalf("cc.success: %v, got error: %v", cc.success, err) @@ -900,12 +915,14 @@ func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []f } expect := &expectedData{ - serverAddr: te.srvAddr, - compression: tc.compress, - method: method, - requests: reqs, - responses: resps, - err: err, + serverAddr: te.srvAddr, + compression: tc.compress, + method: method, + requests: reqs, + responses: resps, + err: err, + isClientStream: isClientStream, + isServerStream: isServerStream, } h.mu.Lock() @@ -1138,6 +1155,9 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map method string err error + isClientStream bool + isServerStream bool + req proto.Message resp proto.Message e error @@ -1154,14 +1174,18 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map reqs, resp, e = te.doClientStreamCall(cc) resps = 
[]proto.Message{resp} err = e + isClientStream = true case serverStreamRPC: method = "/grpc.testing.TestService/StreamingOutputCall" req, resps, e = te.doServerStreamCall(cc) reqs = []proto.Message{req} err = e + isServerStream = true case fullDuplexStreamRPC: method = "/grpc.testing.TestService/FullDuplexCall" reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) + isClientStream = true + isServerStream = true } if cc.success != (err == nil) { t.Fatalf("cc.success: %v, got error: %v", cc.success, err) @@ -1194,13 +1218,15 @@ func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map } expect := &expectedData{ - serverAddr: te.srvAddr, - compression: tc.compress, - method: method, - requests: reqs, - responses: resps, - failfast: cc.failfast, - err: err, + serverAddr: te.srvAddr, + compression: tc.compress, + method: method, + requests: reqs, + responses: resps, + failfast: cc.failfast, + err: err, + isClientStream: isClientStream, + isServerStream: isServerStream, } h.mu.Lock() diff --git a/stream.go b/stream.go index 1356ad08895a..049ec799f223 100644 --- a/stream.go +++ b/stream.go @@ -299,9 +299,11 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: c.failFast}) beginTime = time.Now() begin := &stats.Begin{ - Client: true, - BeginTime: beginTime, - FailFast: c.failFast, + Client: true, + BeginTime: beginTime, + FailFast: c.failFast, + IsClientStream: desc.ClientStreams, + IsServerStream: desc.ServerStreams, } sh.HandleRPC(ctx, begin) } @@ -423,12 +425,9 @@ func (a *csAttempt) newStream() error { cs.callHdr.PreviousAttempts = cs.numRetries s, err := a.t.NewStream(cs.ctx, cs.callHdr) if err != nil { - if _, ok := err.(transport.PerformedIOError); ok { - // Return without converting to an RPC error so retry code can - // inspect. - return err - } - return toRPCErr(err) + // Return without converting to an RPC error so retry code can + // inspect. + return err } cs.attempt.s = s cs.attempt.p = &parser{r: s} @@ -527,19 +526,28 @@ func (cs *clientStream) commitAttempt() { // shouldRetry returns nil if the RPC should be retried; otherwise it returns // the error that should be returned by the operation. func (cs *clientStream) shouldRetry(err error) error { - unprocessed := false if cs.attempt.s == nil { - pioErr, ok := err.(transport.PerformedIOError) - if ok { - // Unwrap error. - err = toRPCErr(pioErr.Err) - } else { - unprocessed = true + // Error from NewClientStream. + nse, ok := err.(*transport.NewStreamError) + if !ok { + // Unexpected, but assume no I/O was performed and the RPC is not + // fatal, so retry indefinitely. + return nil } - if !ok && !cs.callInfo.failFast { - // In the event of a non-IO operation error from NewStream, we - // never attempted to write anything to the wire, so we can retry - // indefinitely for non-fail-fast RPCs. + + // Unwrap and convert error. + err = toRPCErr(nse.Err) + + // Never retry DoNotRetry errors, which indicate the RPC should not be + // retried due to max header list size violation, etc. + if nse.DoNotRetry { + return err + } + + // In the event of a non-IO operation error from NewStream, we never + // attempted to write anything to the wire, so we can retry + // indefinitely. + if !nse.PerformedIO { return nil } } @@ -548,6 +556,7 @@ func (cs *clientStream) shouldRetry(err error) error { return err } // Wait for the trailers. 
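// (unprocessed, computed below, reports whether the transport indicated that
// the server never processed this attempt, e.g. via a GOAWAY frame; such
// attempts are eligible for transparent retry.)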
+ unprocessed := false if cs.attempt.s != nil { <-cs.attempt.s.Done() unprocessed = cs.attempt.s.Unprocessed() @@ -636,7 +645,7 @@ func (cs *clientStream) shouldRetry(err error) error { // Returns nil if a retry was performed and succeeded; error otherwise. func (cs *clientStream) retryLocked(lastErr error) error { for { - cs.attempt.finish(lastErr) + cs.attempt.finish(toRPCErr(lastErr)) if err := cs.shouldRetry(lastErr); err != nil { cs.commitAttemptLocked() return err @@ -663,7 +672,11 @@ func (cs *clientStream) withRetry(op func(a *csAttempt) error, onSuccess func()) for { if cs.committed { cs.mu.Unlock() - return op(cs.attempt) + // toRPCErr is used in case the error from the attempt comes from + // NewClientStream, which intentionally doesn't return a status + // error to allow for further inspection; all other errors should + // already be status errors. + return toRPCErr(op(cs.attempt)) } a := cs.attempt cs.mu.Unlock() diff --git a/test/end2end_test.go b/test/end2end_test.go index 552f74e1b795..3d941b187bf9 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -1453,7 +1453,7 @@ func (s) TestDetailedGoawayErrorOnAbruptClosePropagatesToRPCError(t *testing.T) if err != nil { t.Fatalf("%v.FullDuplexCall = _, %v, want _, ", ss.Client, err) } - const expectedErrorMessageSubstring = "received prior goaway: code: ENHANCE_YOUR_CALM, debug data: too_many_pings" + const expectedErrorMessageSubstring = `received prior goaway: code: ENHANCE_YOUR_CALM, debug data: "too_many_pings"` _, err = stream.Recv() close(rpcDoneOnClient) if err == nil || !strings.Contains(err.Error(), expectedErrorMessageSubstring) { @@ -7254,7 +7254,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { ":status", "403", "content-type", "application/grpc", }, - errCode: codes.Unknown, + errCode: codes.PermissionDenied, }, { // malformed grpc-status. @@ -7273,7 +7273,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { "grpc-status", "0", "grpc-tags-bin", "???", }, - errCode: codes.Internal, + errCode: codes.Unavailable, }, { // gRPC status error. 
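The recoded expectations in the surrounding table-driven cases move transport-level breakage (bad HTTP status, missing gRPC framing) onto UNAVAILABLE, a code clients commonly treat as retryable, while INTERNAL stays reserved for malformed response content. A minimal sketch of how a caller might branch on the resulting codes; the `classify` helper is illustrative, not part of this patch:

```go
package main

import (
	"fmt"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// classify buckets an RPC error the way the updated test cases expect:
// transport-level breakage surfaces as Unavailable, malformed response
// content as Internal.
func classify(err error) string {
	switch status.Code(err) { // status.Code(nil) yields codes.OK
	case codes.OK:
		return "success"
	case codes.Unavailable:
		return "transport-level failure; eligible for retry"
	case codes.Internal:
		return "malformed response content; not retried"
	default:
		return fmt.Sprintf("application-level code: %v", status.Code(err))
	}
}

func main() {
	fmt.Println(classify(nil))
}
```

Keying off status.Code rather than matching message strings also insulates callers from wording changes such as the goaway debug-data quoting fixed above.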
@@ -7282,14 +7282,14 @@ func (s) TestHTTPHeaderFrameErrorHandlingInitialHeader(t *testing.T) { "content-type", "application/grpc", "grpc-status", "3", }, - errCode: codes.InvalidArgument, + errCode: codes.Unavailable, }, } { doHTTPHeaderTest(t, test.errCode, test.header) } } -// Testing non-Trailers-only Trailers (delievered in second HEADERS frame) +// Testing non-Trailers-only Trailers (delivered in second HEADERS frame) func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { for _, test := range []struct { responseHeader []string @@ -7305,7 +7305,7 @@ func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { // trailer missing grpc-status ":status", "502", }, - errCode: codes.Unknown, + errCode: codes.Unavailable, }, { responseHeader: []string{ @@ -7317,6 +7317,18 @@ func (s) TestHTTPHeaderFrameErrorHandlingNormalTrailer(t *testing.T) { "grpc-status", "0", "grpc-status-details-bin", "????", }, + errCode: codes.Unimplemented, + }, + { + responseHeader: []string{ + ":status", "200", + "content-type", "application/grpc", + }, + trailer: []string{ + // malformed grpc-status-details-bin field + "grpc-status", "0", + "grpc-status-details-bin", "????", + }, errCode: codes.Internal, }, } { @@ -7668,3 +7680,89 @@ func (s) TestClientSettingsFloodCloseConn(t *testing.T) { s.GracefulStop() timer.Stop() } + +// TestDeadlineSetOnConnectionOnClientCredentialHandshake tests that there is a deadline +// set on the net.Conn when a credential handshake happens in http2_client. +func (s) TestDeadlineSetOnConnectionOnClientCredentialHandshake(t *testing.T) { + lis, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("Failed to listen: %v", err) + } + connCh := make(chan net.Conn, 1) + go func() { + defer close(connCh) + conn, err := lis.Accept() + if err != nil { + t.Errorf("Error accepting connection: %v", err) + return + } + connCh <- conn + }() + defer func() { + conn := <-connCh + if conn != nil { + conn.Close() + } + }() + deadlineCh := testutils.NewChannel() + cvd := &credentialsVerifyDeadline{ + deadlineCh: deadlineCh, + } + dOpt := grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + return &infoConn{Conn: conn}, nil + }) + cc, err := grpc.Dial(lis.Addr().String(), dOpt, grpc.WithTransportCredentials(cvd)) + if err != nil { + t.Fatalf("Failed to dial: %v", err) + } + defer cc.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + deadline, err := deadlineCh.Receive(ctx) + if err != nil { + t.Fatalf("Error receiving from credsInvoked: %v", err) + } + // Default connection timeout is 20 seconds, so if the deadline exceeds now + // + 18 seconds it should be valid. 
+ if !deadline.(time.Time).After(time.Now().Add(time.Second * 18)) { + t.Fatalf("Connection did not have deadline set.") + } +} + +type infoConn struct { + net.Conn + deadline time.Time +} + +func (c *infoConn) SetDeadline(t time.Time) error { + c.deadline = t + return c.Conn.SetDeadline(t) +} + +type credentialsVerifyDeadline struct { + deadlineCh *testutils.Channel +} + +func (cvd *credentialsVerifyDeadline) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + return rawConn, nil, nil +} + +func (cvd *credentialsVerifyDeadline) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + cvd.deadlineCh.Send(rawConn.(*infoConn).deadline) + return rawConn, nil, nil +} + +func (cvd *credentialsVerifyDeadline) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{} +} +func (cvd *credentialsVerifyDeadline) Clone() credentials.TransportCredentials { + return cvd +} +func (cvd *credentialsVerifyDeadline) OverrideServerName(s string) error { + return nil +} diff --git a/test/kokoro/xds_k8s.cfg b/test/kokoro/xds_k8s.cfg index e9a82b002b9e..4d5e019991f6 100644 --- a/test/kokoro/xds_k8s.cfg +++ b/test/kokoro/xds_k8s.cfg @@ -1,7 +1,7 @@ # Config file for internal CI # Location of the continuous shell script in repository. -build_file: "grpc-go/test/kokoro/xds-k8s.sh" +build_file: "grpc-go/test/kokoro/xds_k8s.sh" timeout_mins: 120 action { diff --git a/test/kokoro/xds_k8s.sh b/test/kokoro/xds_k8s.sh index 4d1fda18a9f9..dbe200962c18 100755 --- a/test/kokoro/xds_k8s.sh +++ b/test/kokoro/xds_k8s.sh @@ -129,7 +129,7 @@ run_test() { main() { local script_dir script_dir="$(dirname "$0")" - # shellcheck source=tools/internal_ci/linux/grpc_xds_k8s_install_test_driver.sh + # shellcheck source=test/kokoro/xds_k8s_install_test_driver.sh source "${script_dir}/xds_k8s_install_test_driver.sh" set -x if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then diff --git a/test/kokoro/xds_k8s_install_test_driver.sh b/test/kokoro/xds_k8s_install_test_driver.sh index 54d9127608c3..a342bd995f92 100755 --- a/test/kokoro/xds_k8s_install_test_driver.sh +++ b/test/kokoro/xds_k8s_install_test_driver.sh @@ -19,7 +19,7 @@ set -eo pipefail readonly PYTHON_VERSION="3.6" # Test driver readonly TEST_DRIVER_REPO_NAME="grpc" -readonly TEST_DRIVER_REPO_URL="https://github.com/grpc/grpc.git" +readonly TEST_DRIVER_REPO_URL="https://github.com/${TEST_DRIVER_REPO_OWNER:-grpc}/grpc.git" readonly TEST_DRIVER_BRANCH="${TEST_DRIVER_BRANCH:-master}" readonly TEST_DRIVER_PATH="tools/run_tests/xds_k8s_test_driver" readonly TEST_DRIVER_PROTOS_PATH="src/proto/grpc/testing" diff --git a/test/kokoro/xds_url_map.cfg b/test/kokoro/xds_url_map.cfg new file mode 100644 index 000000000000..f6fd84a419af --- /dev/null +++ b/test/kokoro/xds_url_map.cfg @@ -0,0 +1,13 @@ +# Config file for internal CI + +# Location of the continuous shell script in repository. +build_file: "grpc-go/test/kokoro/xds_url_map.sh" +timeout_mins: 60 + +action { + define_artifacts { + regex: "artifacts/**/*sponge_log.xml" + regex: "artifacts/**/*sponge_log.log" + strip_prefix: "artifacts" + } +} diff --git a/test/kokoro/xds_url_map.sh b/test/kokoro/xds_url_map.sh new file mode 100755 index 000000000000..f30fd1a15b57 --- /dev/null +++ b/test/kokoro/xds_url_map.sh @@ -0,0 +1,134 @@ +#!/usr/bin/env bash +# Copyright 2021 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -eo pipefail + +# Constants +readonly GITHUB_REPOSITORY_NAME="grpc-go" +# GKE Cluster +readonly GKE_CLUSTER_NAME="interop-test-psm-sec-v2-us-central1-a" +readonly GKE_CLUSTER_ZONE="us-central1-a" +## xDS test client Docker images +readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/go-client" +readonly FORCE_IMAGE_BUILD="${FORCE_IMAGE_BUILD:-0}" + +####################################### +# Builds test app Docker images and pushes them to GCR +# Globals: +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# None +# Outputs: +# Writes the output of `gcloud builds submit` to stdout, stderr +####################################### +build_test_app_docker_images() { + echo "Building Go xDS interop test app Docker images" + docker build -f "${SRC_DIR}/interop/xds/client/Dockerfile" -t "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" "${SRC_DIR}" + gcloud -q auth configure-docker + docker push "${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" +} + +####################################### +# Builds test app and its docker images unless they already exist +# Globals: +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# FORCE_IMAGE_BUILD +# Arguments: +# None +# Outputs: +# Writes the output to stdout, stderr +####################################### +build_docker_images_if_needed() { + # Check if images already exist + client_tags="$(gcloud_gcr_list_image_tags "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}")" + printf "Client image: %s:%s\n" "${CLIENT_IMAGE_NAME}" "${GIT_COMMIT}" + echo "${client_tags:-Client image not found}" + + # Build if any of the images are missing, or FORCE_IMAGE_BUILD=1 + if [[ "${FORCE_IMAGE_BUILD}" == "1" || -z "${client_tags}" ]]; then + build_test_app_docker_images + else + echo "Skipping Go test app build" + fi +} + +####################################### +# Executes the test case +# Globals: +# TEST_DRIVER_FLAGFILE: Relative path to test driver flagfile +# KUBE_CONTEXT: The name of kubectl context with GKE cluster access +# TEST_XML_OUTPUT_DIR: Output directory for the test xUnit XML report +# CLIENT_IMAGE_NAME: Test client Docker image name +# GIT_COMMIT: SHA-1 of git commit being built +# Arguments: +# Test case name +# Outputs: +# Writes the output of test execution to stdout, stderr +# Test xUnit report to ${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml +####################################### +run_test() { + # Test driver usage: + # https://github.com/grpc/grpc/tree/master/tools/run_tests/xds_k8s_test_driver#basic-usage + local test_name="${1:?Usage: run_test test_name}" + set -x + python -m "tests.${test_name}" \ + --flagfile="${TEST_DRIVER_FLAGFILE}" \ + --kube_context="${KUBE_CONTEXT}" \ + --client_image="${CLIENT_IMAGE_NAME}:${GIT_COMMIT}" \ + --xml_output_file="${TEST_XML_OUTPUT_DIR}/${test_name}/sponge_log.xml" \ + --flagfile="config/url-map.cfg" + set +x +} + +####################################### +# Main function: provision software necessary to execute tests, and run them +# Globals: +# KOKORO_ARTIFACTS_DIR +# GITHUB_REPOSITORY_NAME +# SRC_DIR: 
Populated with absolute path to the source repo +# TEST_DRIVER_REPO_DIR: Populated with the path to the repo containing +# the test driver +# TEST_DRIVER_FULL_DIR: Populated with the path to the test driver source code +# TEST_DRIVER_FLAGFILE: Populated with relative path to test driver flagfile +# TEST_XML_OUTPUT_DIR: Populated with the path to test xUnit XML report +# GIT_ORIGIN_URL: Populated with the origin URL of git repo used for the build +# GIT_COMMIT: Populated with the SHA-1 of git commit being built +# GIT_COMMIT_SHORT: Populated with the short SHA-1 of git commit being built +# KUBE_CONTEXT: Populated with name of kubectl context with GKE cluster access +# Arguments: +# None +# Outputs: +# Writes the output of test execution to stdout, stderr +####################################### +main() { + local script_dir + script_dir="$(dirname "$0")" + # shellcheck source=test/kokoro/xds_k8s_install_test_driver.sh + source "${script_dir}/xds_k8s_install_test_driver.sh" + set -x + if [[ -n "${KOKORO_ARTIFACTS_DIR}" ]]; then + kokoro_setup_test_driver "${GITHUB_REPOSITORY_NAME}" + else + local_setup_test_driver "${script_dir}" + fi + build_docker_images_if_needed + # Run tests + cd "${TEST_DRIVER_FULL_DIR}" + run_test url_map +} + +main "$@" diff --git a/version.go b/version.go index 4e26aec6ac18..9717a560b9ba 100644 --- a/version.go +++ b/version.go @@ -19,4 +19,4 @@ package grpc // Version is the current grpc version. -const Version = "1.39.0-dev" +const Version = "1.41.0-dev" diff --git a/vet.sh b/vet.sh index 1a0dbd7ee5ab..5eaa8b05d6d3 100755 --- a/vet.sh +++ b/vet.sh @@ -32,26 +32,14 @@ PATH="${HOME}/go/bin:${GOROOT}/bin:${PATH}" go version if [[ "$1" = "-install" ]]; then - # Check for module support - if go help mod >& /dev/null; then - # Install the pinned versions as defined in module tools. - pushd ./test/tools - go install \ - golang.org/x/lint/golint \ - golang.org/x/tools/cmd/goimports \ - honnef.co/go/tools/cmd/staticcheck \ - github.com/client9/misspell/cmd/misspell - popd - else - # Ye olde `go get` incantation. - # Note: this gets the latest version of all tools (vs. the pinned versions - # with Go modules). - go get -u \ - golang.org/x/lint/golint \ - golang.org/x/tools/cmd/goimports \ - honnef.co/go/tools/cmd/staticcheck \ - github.com/client9/misspell/cmd/misspell - fi + # Install the pinned versions as defined in module tools. + pushd ./test/tools + go install \ + golang.org/x/lint/golint \ + golang.org/x/tools/cmd/goimports \ + honnef.co/go/tools/cmd/staticcheck \ + github.com/client9/misspell/cmd/misspell + popd if [[ -z "${VET_SKIP_PROTO}" ]]; then if [[ "${TRAVIS}" = "true" ]]; then PROTOBUF_VERSION=3.14.0 diff --git a/xds/csds/csds.go b/xds/csds/csds.go index 73b92e9443ce..1b54a3a4c6e3 100644 --- a/xds/csds/csds.go +++ b/xds/csds/csds.go @@ -25,7 +25,6 @@ package csds import ( "context" - "fmt" "io" "time" @@ -38,49 +37,36 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/status" - "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient" "google.golang.org/protobuf/types/known/timestamppb" - _ "google.golang.org/grpc/xds/internal/client/v2" // Register v2 xds_client. - _ "google.golang.org/grpc/xds/internal/client/v3" // Register v3 xds_client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register v2 xds_client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register v3 xds_client. 
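// (The blank imports above are intentional: importing the v2 and v3 packages
// registers those xDS API client implementations; nothing in them is
// referenced directly.)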
) -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the server. This is useful for overriding in unit tests. -type xdsClientInterface interface { - DumpLDS() (string, map[string]client.UpdateWithMD) - DumpRDS() (string, map[string]client.UpdateWithMD) - DumpCDS() (string, map[string]client.UpdateWithMD) - DumpEDS() (string, map[string]client.UpdateWithMD) - BootstrapConfig() *bootstrap.Config - Close() -} - var ( logger = grpclog.Component("xds") - newXDSClient = func() (xdsClientInterface, error) { - return client.New() + newXDSClient = func() xdsclient.XDSClient { + c, err := xdsclient.New() + if err != nil { + logger.Warningf("failed to create xds client: %v", err) + return nil + } + return c } ) // ClientStatusDiscoveryServer implementations interface ClientStatusDiscoveryServiceServer. type ClientStatusDiscoveryServer struct { - // xdsClient will always be the same in practise. But we keep a copy in each + // xdsClient will always be the same in practice. But we keep a copy in each // server instance for testing. - xdsClient xdsClientInterface + xdsClient xdsclient.XDSClient } // NewClientStatusDiscoveryServer returns an implementation of the CSDS server that can be // registered on a gRPC server. func NewClientStatusDiscoveryServer() (*ClientStatusDiscoveryServer, error) { - xdsC, err := newXDSClient() - if err != nil { - return nil, fmt.Errorf("failed to create xds client: %v", err) - } - return &ClientStatusDiscoveryServer{ - xdsClient: xdsC, - }, nil + return &ClientStatusDiscoveryServer{xdsClient: newXDSClient()}, nil } // StreamClientStatus implementations interface ClientStatusDiscoveryServiceServer. @@ -109,10 +95,13 @@ func (s *ClientStatusDiscoveryServer) FetchClientStatus(_ context.Context, req * } // buildClientStatusRespForReq fetches the status from the client, and returns -// the response to be sent back to client. +// the response to be sent back to xdsclient. // // If it returns an error, the error is a status error. func (s *ClientStatusDiscoveryServer) buildClientStatusRespForReq(req *v3statuspb.ClientStatusRequest) (*v3statuspb.ClientStatusResponse, error) { + if s.xdsClient == nil { + return &v3statuspb.ClientStatusResponse{}, nil + } // Field NodeMatchers is unsupported, by design // https://github.com/grpc/proposal/blob/master/A40-csds-support.md#detail-node-matching. if len(req.NodeMatchers) != 0 { @@ -137,7 +126,9 @@ func (s *ClientStatusDiscoveryServer) buildClientStatusRespForReq(req *v3statusp // Close cleans up the resources. func (s *ClientStatusDiscoveryServer) Close() { - s.xdsClient.Close() + if s.xdsClient != nil { + s.xdsClient.Close() + } } // nodeProtoToV3 converts the given proto into a v3.Node. 
n is from bootstrap @@ -296,17 +287,17 @@ func (s *ClientStatusDiscoveryServer) buildEDSPerXDSConfig() *v3statuspb.PerXdsC } } -func serviceStatusToProto(serviceStatus client.ServiceStatus) v3adminpb.ClientResourceStatus { +func serviceStatusToProto(serviceStatus xdsclient.ServiceStatus) v3adminpb.ClientResourceStatus { switch serviceStatus { - case client.ServiceStatusUnknown: + case xdsclient.ServiceStatusUnknown: return v3adminpb.ClientResourceStatus_UNKNOWN - case client.ServiceStatusRequested: + case xdsclient.ServiceStatusRequested: return v3adminpb.ClientResourceStatus_REQUESTED - case client.ServiceStatusNotExist: + case xdsclient.ServiceStatusNotExist: return v3adminpb.ClientResourceStatus_DOES_NOT_EXIST - case client.ServiceStatusACKed: + case xdsclient.ServiceStatusACKed: return v3adminpb.ClientResourceStatus_ACKED - case client.ServiceStatusNACKed: + case xdsclient.ServiceStatusNACKed: return v3adminpb.ClientResourceStatus_NACKED default: return v3adminpb.ClientResourceStatus_UNKNOWN diff --git a/xds/csds/csds_test.go b/xds/csds/csds_test.go index 6cf88f6d3942..98dc93e86713 100644 --- a/xds/csds/csds_test.go +++ b/xds/csds/csds_test.go @@ -36,10 +36,10 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/xds" - "google.golang.org/grpc/xds/internal/client" _ "google.golang.org/grpc/xds/internal/httpfilter/router" xtestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/e2e" + "google.golang.org/grpc/xds/internal/xdsclient" "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/timestamppb" @@ -59,13 +59,6 @@ const ( defaultTestTimeout = 10 * time.Second ) -type xdsClientInterfaceWithWatch interface { - WatchListener(string, func(client.ListenerUpdate, error)) func() - WatchRouteConfig(string, func(client.RouteConfigUpdate, error)) func() - WatchCluster(string, func(client.ClusterUpdate, error)) func() - WatchEndpoints(string, func(client.EndpointsUpdate, error)) func() -} - var cmpOpts = cmp.Options{ cmpopts.EquateEmpty(), cmp.Comparer(func(a, b *timestamppb.Timestamp) bool { return true }), @@ -174,16 +167,16 @@ func TestCSDS(t *testing.T) { defer cleanup() for _, target := range ldsTargets { - xdsC.WatchListener(target, func(client.ListenerUpdate, error) {}) + xdsC.WatchListener(target, func(xdsclient.ListenerUpdate, error) {}) } for _, target := range rdsTargets { - xdsC.WatchRouteConfig(target, func(client.RouteConfigUpdate, error) {}) + xdsC.WatchRouteConfig(target, func(xdsclient.RouteConfigUpdate, error) {}) } for _, target := range cdsTargets { - xdsC.WatchCluster(target, func(client.ClusterUpdate, error) {}) + xdsC.WatchCluster(target, func(xdsclient.ClusterUpdate, error) {}) } for _, target := range edsTargets { - xdsC.WatchEndpoints(target, func(client.EndpointsUpdate, error) {}) + xdsC.WatchEndpoints(target, func(xdsclient.EndpointsUpdate, error) {}) } for i := 0; i < retryCount; i++ { @@ -250,7 +243,7 @@ func TestCSDS(t *testing.T) { } } -func commonSetup(t *testing.T) (xdsClientInterfaceWithWatch, *e2e.ManagementServer, string, v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient, func()) { +func commonSetup(t *testing.T) (xdsclient.XDSClient, *e2e.ManagementServer, string, v3statuspbgrpc.ClientStatusDiscoveryService_StreamClientStatusClient, func()) { t.Helper() // Spin up a xDS management server on a local port. 
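For reference, an application exposes CSDS with the same two calls the test below makes; a minimal sketch, assuming the go-control-plane status/v3 stubs imported elsewhere in this patch, with error handling reduced to panics and an arbitrary listen address:

```go
package main

import (
	"net"

	v3statuspbgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3"
	"google.golang.org/grpc"
	"google.golang.org/grpc/xds/csds"
)

func main() {
	// With this patch, construction succeeds even when no xDS client can be
	// built; the server then replies with empty config dumps.
	csdss, err := csds.NewClientStatusDiscoveryServer()
	if err != nil {
		panic(err)
	}
	defer csdss.Close()

	server := grpc.NewServer()
	v3statuspbgrpc.RegisterClientStatusDiscoveryServiceServer(server, csdss)

	lis, err := net.Listen("tcp", "localhost:50051")
	if err != nil {
		panic(err)
	}
	_ = server.Serve(lis)
}
```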
@@ -270,14 +263,12 @@ func commonSetup(t *testing.T) (xdsClientInterfaceWithWatch, *e2e.ManagementServ t.Fatal(err) } // Create xds_client. - xdsC, err := client.New() + xdsC, err := xdsclient.New() if err != nil { t.Fatalf("failed to create xds client: %v", err) } oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { - return xdsC, nil - } + newXDSClient = func() xdsclient.XDSClient { return xdsC } // Initialize a gRPC server and register CSDS on it. server := grpc.NewServer() @@ -635,6 +626,58 @@ func protoToJSON(p proto.Message) string { return ret } +func TestCSDSNoXDSClient(t *testing.T) { + oldNewXDSClient := newXDSClient + newXDSClient = func() xdsclient.XDSClient { return nil } + defer func() { newXDSClient = oldNewXDSClient }() + + // Initialize a gRPC server and register CSDS on it. + server := grpc.NewServer() + csdss, err := NewClientStatusDiscoveryServer() + if err != nil { + t.Fatal(err) + } + defer csdss.Close() + v3statuspbgrpc.RegisterClientStatusDiscoveryServiceServer(server, csdss) + // Create a local listener and pass it to Serve(). + lis, err := xtestutils.LocalTCPListener() + if err != nil { + t.Fatalf("xtestutils.LocalTCPListener() failed: %v", err) + } + go func() { + if err := server.Serve(lis); err != nil { + t.Errorf("Serve() failed: %v", err) + } + }() + defer server.Stop() + + // Create CSDS client. + conn, err := grpc.Dial(lis.Addr().String(), grpc.WithInsecure()) + if err != nil { + t.Fatalf("cannot connect to server: %v", err) + } + defer conn.Close() + c := v3statuspbgrpc.NewClientStatusDiscoveryServiceClient(conn) + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + stream, err := c.StreamClientStatus(ctx, grpc.WaitForReady(true)) + if err != nil { + t.Fatalf("cannot start StreamClientStatus stream: %v", err) + } + + if err := stream.Send(&v3statuspb.ClientStatusRequest{Node: nil}); err != nil { + t.Fatalf("failed to send: %v", err) + } + r, err := stream.Recv() + if err != nil { + // io.EOF is not ok. + t.Fatalf("failed to recv response: %v", err) + } + if n := len(r.Config); n != 0 { + t.Fatalf("got %d configs, want 0: %v", n, proto.MarshalTextString(r)) + } +} + func Test_nodeProtoToV3(t *testing.T) { const ( testID = "test-id" diff --git a/xds/googledirectpath/googlec2p.go b/xds/googledirectpath/googlec2p.go index 4ccec4ec4120..0c2f984fbcb1 100644 --- a/xds/googledirectpath/googlec2p.go +++ b/xds/googledirectpath/googlec2p.go @@ -35,12 +35,13 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/internal/googlecloud" internalgrpclog "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcrand" "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/resolver" _ "google.golang.org/grpc/xds" // To register xds resolvers and balancers. - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/types/known/structpb" ) @@ -61,15 +62,11 @@ const ( dnsName, xdsName = "dns", "xds" ) -type xdsClientInterface interface { - Close() -} - // For overriding in unittests.
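// (Tests replace these package-level constructors to inject fakes; the
// googlec2p_test.go changes below swap newClientWithConfig this way.)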
var (
 	onGCE = googlecloud.OnGCE
-	newClientWithConfig = func(config *bootstrap.Config) (xdsClientInterface, error) {
+	newClientWithConfig = func(config *bootstrap.Config) (xdsclient.XDSClient, error) {
 		return xdsclient.NewWithConfig(config)
 	}
@@ -138,7 +135,7 @@ func (c2pResolverBuilder) Scheme() string {
 type c2pResolver struct {
 	resolver.Resolver
-	client xdsClientInterface
+	client xdsclient.XDSClient
 }
 func (r *c2pResolver) Close() {
@@ -152,13 +149,15 @@ var ipv6EnabledMetadata = &structpb.Struct{
 	},
 }
+var id = fmt.Sprintf("C2P-%d", grpcrand.Int())
+
 // newNode makes a copy of defaultNode, and populates its Metadata and
 // Locality fields.
 func newNode(zone string, ipv6Capable bool) *v3corepb.Node {
 	ret := &v3corepb.Node{
 		// Not all required fields are set in defaultNode. Metadata will be set
 		// if ipv6 is enabled. Locality will be set to the value from metadata.
-		Id:                   "C2P",
+		Id:                   id,
 		UserAgentName:        gRPCUserAgentName,
 		UserAgentVersionType: &v3corepb.Node_UserAgentVersion{UserAgentVersion: grpc.Version},
 		ClientFeatures:       []string{clientFeatureNoOverprovisioning},
diff --git a/xds/googledirectpath/googlec2p_test.go b/xds/googledirectpath/googlec2p_test.go
index 524bb82e0f39..8f98d3159d3a 100644
--- a/xds/googledirectpath/googlec2p_test.go
+++ b/xds/googledirectpath/googlec2p_test.go
@@ -31,8 +31,9 @@ import (
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/internal/xds/env"
 	"google.golang.org/grpc/resolver"
-	"google.golang.org/grpc/xds/internal/client/bootstrap"
 	"google.golang.org/grpc/xds/internal/version"
+	"google.golang.org/grpc/xds/internal/xdsclient"
+	"google.golang.org/grpc/xds/internal/xdsclient/bootstrap"
 	"google.golang.org/protobuf/testing/protocmp"
 	"google.golang.org/protobuf/types/known/structpb"
 )
@@ -130,6 +131,7 @@ func TestBuildNotOnGCE(t *testing.T) {
 }
 type testXDSClient struct {
+	xdsclient.XDSClient
 	closed chan struct{}
 }
@@ -177,7 +179,7 @@ func TestBuildXDS(t *testing.T) {
 	configCh := make(chan *bootstrap.Config, 1)
 	oldNewClient := newClientWithConfig
-	newClientWithConfig = func(config *bootstrap.Config) (xdsClientInterface, error) {
+	newClientWithConfig = func(config *bootstrap.Config) (xdsclient.XDSClient, error) {
 		configCh <- config
 		return tXDSClient, nil
 	}
@@ -194,7 +196,7 @@
 	}
 	wantNode := &v3corepb.Node{
-		Id:                   "C2P",
+		Id:                   id,
 		Metadata:             nil,
 		Locality:             &v3corepb.Locality{Zone: testZone},
 		UserAgentName:        gRPCUserAgentName,
diff --git a/xds/internal/balancer/balancer.go b/xds/internal/balancer/balancer.go
index 5883027a2c52..86656736a61b 100644
--- a/xds/internal/balancer/balancer.go
+++ b/xds/internal/balancer/balancer.go
@@ -20,8 +20,10 @@ package balancer
 import (
-	_ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer"    // Register the CDS balancer
-	_ "google.golang.org/grpc/xds/internal/balancer/clustermanager" // Register the xds_cluster_manager balancer
-	_ "google.golang.org/grpc/xds/internal/balancer/edsbalancer"    // Register the EDS balancer
-	_ "google.golang.org/grpc/xds/internal/balancer/weightedtarget" // Register the weighted_target balancer
+	_ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer"     // Register the CDS balancer
+	_ "google.golang.org/grpc/xds/internal/balancer/clusterimpl"     // Register the xds_cluster_impl balancer
+	_ "google.golang.org/grpc/xds/internal/balancer/clustermanager"  // Register the xds_cluster_manager balancer
+	_ "google.golang.org/grpc/xds/internal/balancer/clusterresolver" // Register the xds_cluster_resolver balancer
+	_
"google.golang.org/grpc/xds/internal/balancer/priority" // Register the priority balancer + _ "google.golang.org/grpc/xds/internal/balancer/weightedtarget" // Register the weighted_target balancer ) diff --git a/xds/internal/balancer/balancergroup/balancergroup.go b/xds/internal/balancer/balancergroup/balancergroup.go index 5b6d42a25e44..6d54728dc919 100644 --- a/xds/internal/balancer/balancergroup/balancergroup.go +++ b/xds/internal/balancer/balancergroup/balancergroup.go @@ -24,7 +24,7 @@ import ( "time" orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" @@ -183,7 +183,7 @@ type BalancerGroup struct { cc balancer.ClientConn buildOpts balancer.BuildOptions logger *grpclog.PrefixLogger - loadStore load.PerClusterReporter + loadStore load.PerClusterReporter // TODO: delete this, no longer needed. It was used by EDS. // stateAggregator is where the state/picker updates will be sent to. It's // provided by the parent balancer, to build a picker with all the diff --git a/xds/internal/balancer/balancergroup/balancergroup_test.go b/xds/internal/balancer/balancergroup/balancergroup_test.go index 1ba9195ab1d0..355810a52004 100644 --- a/xds/internal/balancer/balancergroup/balancergroup_test.go +++ b/xds/internal/balancer/balancergroup/balancergroup_test.go @@ -44,8 +44,8 @@ import ( "google.golang.org/grpc/internal/balancer/stub" "google.golang.org/grpc/resolver" "google.golang.org/grpc/xds/internal/balancer/weightedtarget/weightedaggregator" - "google.golang.org/grpc/xds/internal/client/load" "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) var ( @@ -845,7 +845,7 @@ func (s) TestBalancerGroup_locality_caching_not_readd_within_timeout(t *testing. defer replaceDefaultSubBalancerCloseTimeout(time.Second)() _, _, cc, addrToSC := initBalancerGroupForCachingTest(t) - // The sub-balancer is not re-added withtin timeout. The subconns should be + // The sub-balancer is not re-added within timeout. The subconns should be // removed. removeTimeout := time.After(DefaultSubBalancerCloseTimeout) scToRemove := map[balancer.SubConn]int{ diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer.go b/xds/internal/balancer/cdsbalancer/cdsbalancer.go index 0e8b83481ac9..3ea14add9ebf 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer.go @@ -34,65 +34,52 @@ import ( "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/edsbalancer" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/balancer/clusterresolver" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( cdsName = "cds_experimental" - edsName = "eds_experimental" ) var ( errBalancerClosed = errors.New("cdsBalancer is closed") - // newEDSBalancer is a helper function to build a new edsBalancer and will be - // overridden in unittests. - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { - builder := balancer.Get(edsName) + // newChildBalancer is a helper function to build a new cluster_resolver + // balancer and will be overridden in unittests. 
+ newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + builder := balancer.Get(clusterresolver.Name) if builder == nil { - return nil, fmt.Errorf("xds: no balancer builder with name %v", edsName) + return nil, fmt.Errorf("xds: no balancer builder with name %v", clusterresolver.Name) } - // We directly pass the parent clientConn to the - // underlying edsBalancer because the cdsBalancer does - // not deal with subConns. + // We directly pass the parent clientConn to the underlying + // cluster_resolver balancer because the cdsBalancer does not deal with + // subConns. return builder.Build(cc, opts), nil } - newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } buildProvider = buildProviderFunc ) func init() { - balancer.Register(cdsBB{}) + balancer.Register(bb{}) } -// cdsBB (short for cdsBalancerBuilder) implements the balancer.Builder -// interface to help build a cdsBalancer. +// bb implements the balancer.Builder interface to help build a cdsBalancer. // It also implements the balancer.ConfigParser interface to help parse the // JSON service config, to be passed to the cdsBalancer. -type cdsBB struct{} +type bb struct{} // Build creates a new CDS balancer with the ClientConn. -func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { b := &cdsBalancer{ - bOpts: opts, - updateCh: buffer.NewUnbounded(), - closed: grpcsync.NewEvent(), - done: grpcsync.NewEvent(), - cancelWatch: func() {}, // No-op at this point. - xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), + bOpts: opts, + updateCh: buffer.NewUnbounded(), + closed: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), + xdsHI: xdsinternal.NewHandshakeInfo(nil, nil), } b.logger = prefixLogger((b)) b.logger.Infof("Created") - - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsClient = client - var creds credentials.TransportCredentials switch { case opts.DialCreds != nil: @@ -104,7 +91,7 @@ func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer. b.xdsCredsInUse = true } b.logger.Infof("xDS credentials in use: %v", b.xdsCredsInUse) - + b.clusterHandler = newClusterHandler(b) b.ccw = &ccWrapper{ ClientConn: cc, xdsHI: b.xdsHI, @@ -114,7 +101,7 @@ func (cdsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer. } // Name returns the name of balancers built by this builder. -func (cdsBB) Name() string { +func (bb) Name() string { return cdsName } @@ -127,7 +114,7 @@ type lbConfig struct { // ParseConfig parses the JSON load balancer config provided into an // internal form or returns an error if the config is invalid. -func (cdsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { var cfg lbConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, fmt.Errorf("xds: unable to unmarshal lbconfig: %s, error: %v", string(c), err) @@ -135,57 +122,35 @@ func (cdsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, return &cfg, nil } -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the cdsBalancer. This will be faked out in unittests. 
-type xdsClientInterface interface { - WatchCluster(string, func(xdsclient.ClusterUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - // ccUpdate wraps a clientConn update received from gRPC (pushed from the // xdsResolver). A valid clusterName causes the cdsBalancer to register a CDS // watcher with the xdsClient, while a non-nil error causes it to cancel the -// existing watch and propagate the error to the underlying edsBalancer. +// existing watch and propagate the error to the underlying cluster_resolver +// balancer. type ccUpdate struct { clusterName string err error } -type clusterHandlerUpdate struct { - chu []xdsclient.ClusterUpdate - err error -} - // scUpdate wraps a subConn update received from gRPC. This is directly passed -// on to the edsBalancer. +// on to the cluster_resolver balancer. type scUpdate struct { subConn balancer.SubConn state balancer.SubConnState } -// watchUpdate wraps the information received from a registered CDS watcher. A -// non-nil error is propagated to the underlying edsBalancer. A valid update -// results in creating a new edsBalancer (if one doesn't already exist) and -// pushing the update to it. -type watchUpdate struct { - cds xdsclient.ClusterUpdate - err error -} - -// cdsBalancer implements a CDS based LB policy. It instantiates an EDS based -// LB policy to further resolve the serviceName received from CDS, into -// localities and endpoints. Implements the balancer.Balancer interface which -// is exposed to gRPC and implements the balancer.ClientConn interface which is -// exposed to the edsBalancer. +// cdsBalancer implements a CDS based LB policy. It instantiates a +// cluster_resolver balancer to further resolve the serviceName received from +// CDS, into localities and endpoints. Implements the balancer.Balancer +// interface which is exposed to gRPC and implements the balancer.ClientConn +// interface which is exposed to the cluster_resolver balancer. type cdsBalancer struct { ccw *ccWrapper // ClientConn interface passed to child LB. bOpts balancer.BuildOptions // BuildOptions passed to child LB. updateCh *buffer.Unbounded // Channel for gRPC and xdsClient updates. - xdsClient xdsClientInterface // xDS client to watch Cluster resource. - cancelWatch func() // Cluster watch cancel func. - edsLB balancer.Balancer // EDS child policy. - clusterToWatch string + xdsClient xdsclient.XDSClient // xDS client to watch Cluster resource. + clusterHandler *clusterHandler // To watch the clusters. + childLB balancer.Balancer logger *grpclog.PrefixLogger closed *grpcsync.Event done *grpcsync.Event @@ -201,25 +166,15 @@ type cdsBalancer struct { // handleClientConnUpdate handles a ClientConnUpdate received from gRPC. Good // updates lead to registration of a CDS watch. Updates with error lead to // cancellation of existing watch and propagation of the same error to the -// edsBalancer. +// cluster_resolver balancer. func (b *cdsBalancer) handleClientConnUpdate(update *ccUpdate) { // We first handle errors, if any, and then proceed with handling the // update, only if the status quo has changed. 
	if err := update.err; err != nil {
 		b.handleErrorFromUpdate(err, true)
-	}
-	if b.clusterToWatch == update.clusterName {
 		return
 	}
-	if update.clusterName != "" {
-		cancelWatch := b.xdsClient.WatchCluster(update.clusterName, b.handleClusterUpdate)
-		b.logger.Infof("Watch started on resource name %v with xds-client %p", update.clusterName, b.xdsClient)
-		b.cancelWatch = func() {
-			cancelWatch()
-			b.logger.Infof("Watch cancelled on resource name %v with xds-client %p", update.clusterName, b.xdsClient)
-		}
-		b.clusterToWatch = update.clusterName
-	}
+	b.clusterHandler.updateRootCluster(update.clusterName)
 }
 // handleSecurityConfig processes the security configuration received from the
@@ -311,22 +266,22 @@ func buildProviderFunc(configs map[string]*certprovider.BuildableConfig, instanc
 }
 // handleWatchUpdate handles a watch update from the xDS Client. Good updates
-// lead to clientConn updates being invoked on the underlying edsBalancer.
-func (b *cdsBalancer) handleWatchUpdate(update *watchUpdate) {
+// lead to clientConn updates being invoked on the underlying cluster_resolver balancer.
+func (b *cdsBalancer) handleWatchUpdate(update clusterHandlerUpdate) {
 	if err := update.err; err != nil {
 		b.logger.Warningf("Watch error from xds-client %p: %v", b.xdsClient, err)
 		b.handleErrorFromUpdate(err, false)
 		return
 	}
-	b.logger.Infof("Watch update from xds-client %p, content: %+v", b.xdsClient, pretty.ToJSON(update.cds))
+	b.logger.Infof("Watch update from xds-client %p, content: %+v, security config: %v", b.xdsClient, pretty.ToJSON(update.updates), pretty.ToJSON(update.securityCfg))
 	// Process the security config from the received update before building the
 	// child policy or forwarding the update to it. We do this because the child
 	// policy may try to create a new subConn inline. Processing the security
 	// configuration here and setting up the handshakeInfo will make sure that
 	// such attempts are handled properly.
-	if err := b.handleSecurityConfig(update.cds.SecurityCfg); err != nil {
+	if err := b.handleSecurityConfig(update.securityCfg); err != nil {
 		// If the security config is invalid, for example, if the provider
 		// instance is not found in the bootstrap config, we need to put the
 		// channel in transient failure.
@@ -336,34 +291,54 @@ func (b *cdsBalancer) handleWatchUpdate(update *watchUpdate) {
 	}
-	// The first good update from the watch API leads to the instantiation of an
-	// edsBalancer. Further updates/errors are propagated to the existing
-	// edsBalancer.
+	// The first good update from the watch API leads to the instantiation of a
+	// cluster_resolver balancer. Further updates/errors are propagated to the
+	// existing cluster_resolver balancer.
+	if b.childLB == nil {
+		childLB, err := newChildBalancer(b.ccw, b.bOpts)
 		if err != nil {
-			b.logger.Errorf("Failed to create child policy of type %s, %v", edsName, err)
+			b.logger.Errorf("Failed to create child policy of type %s, %v", clusterresolver.Name, err)
 			return
 		}
-		b.edsLB = edsLB
-		b.logger.Infof("Created child policy %p of type %s", b.edsLB, edsName)
+		b.childLB = childLB
+		b.logger.Infof("Created child policy %p of type %s", b.childLB, clusterresolver.Name)
+	}
+
+	dms := make([]clusterresolver.DiscoveryMechanism, len(update.updates))
+	for i, cu := range update.updates {
+		switch cu.ClusterType {
+		case xdsclient.ClusterTypeEDS:
+			dms[i] = clusterresolver.DiscoveryMechanism{
+				Type:                  clusterresolver.DiscoveryMechanismTypeEDS,
+				Cluster:               cu.ClusterName,
+				EDSServiceName:        cu.EDSServiceName,
+				MaxConcurrentRequests: cu.MaxRequests,
+			}
+			if cu.EnableLRS {
+				// An empty string here indicates that the cluster_resolver
+				// balancer should use the same xDS server for load reporting
+				// as it does for EDS requests/responses.
+				dms[i].LoadReportingServerName = new(string)
+			}
+		case xdsclient.ClusterTypeLogicalDNS:
+			dms[i] = clusterresolver.DiscoveryMechanism{
+				Type:        clusterresolver.DiscoveryMechanismTypeLogicalDNS,
+				DNSHostname: cu.DNSHostName,
+			}
+		default:
+			b.logger.Infof("unexpected cluster type %v when handling update from cluster handler", cu.ClusterType)
+		}
 	}
-	lbCfg := &edsbalancer.EDSConfig{
-		ClusterName:           update.cds.ClusterName,
-		EDSServiceName:        update.cds.EDSServiceName,
-		MaxConcurrentRequests: update.cds.MaxRequests,
+	lbCfg := &clusterresolver.LBConfig{
+		DiscoveryMechanisms: dms,
 	}
-	if update.cds.EnableLRS {
-		// An empty string here indicates that the edsBalancer should use the
-		// same xDS server for load reporting as it does for EDS
-		// requests/responses.
-		lbCfg.LrsLoadReportingServerName = new(string)
-	}
 	ccState := balancer.ClientConnState{
+		ResolverState:  xdsclient.SetClient(resolver.State{}, b.xdsClient),
 		BalancerConfig: lbCfg,
 	}
-	if err := b.edsLB.UpdateClientConnState(ccState); err != nil {
-		b.logger.Errorf("xds: edsBalancer.UpdateClientConnState(%+v) returned error: %v", ccState, err)
+	if err := b.childLB.UpdateClientConnState(ccState); err != nil {
+		b.logger.Errorf("xds: cluster_resolver balancer.UpdateClientConnState(%+v) returned error: %v", ccState, err)
 	}
 }
@@ -380,24 +355,21 @@ func (b *cdsBalancer) run() {
 			b.handleClientConnUpdate(update)
 			case *scUpdate:
 				// SubConn updates are passthrough and are simply handed over to
-				// the underlying edsBalancer.
-				if b.edsLB == nil {
-					b.logger.Errorf("xds: received scUpdate {%+v} with no edsBalancer", update)
+				// the underlying cluster_resolver balancer.
+				if b.childLB == nil {
+					b.logger.Errorf("xds: received scUpdate {%+v} with no cluster_resolver balancer", update)
 					break
 				}
-				b.edsLB.UpdateSubConnState(update.subConn, update.state)
-			case *watchUpdate:
-				b.handleWatchUpdate(update)
+				b.childLB.UpdateSubConnState(update.subConn, update.state)
 			}
+		case u := <-b.clusterHandler.updateChannel:
+			b.handleWatchUpdate(u)
 		case <-b.closed.Done():
-			b.cancelWatch()
-			b.cancelWatch = func() {}
-
-			if b.edsLB != nil {
-				b.edsLB.Close()
-				b.edsLB = nil
+			b.clusterHandler.close()
+			if b.childLB != nil {
+				b.childLB.Close()
+				b.childLB = nil
 			}
-			b.xdsClient.Close()
 			if b.cachedRoot != nil {
 				b.cachedRoot.Close()
 			}
@@ -424,23 +396,22 @@ func (b *cdsBalancer) run() {
 // - If it's from xds client, it means CDS resources were removed. The CDS
 //   watcher should keep watching.
 //
-// In both cases, the error will be forwarded to EDS balancer. And if error is
-// resource-not-found, the child EDS balancer will stop watching EDS.
+// In both cases, the error will be forwarded to the child balancer. And if
+// the error is resource-not-found, the child balancer will stop watching EDS.
 func (b *cdsBalancer) handleErrorFromUpdate(err error, fromParent bool) {
-	// TODO: connection errors will be sent to the eds balancers directly, and
-	// also forwarded by the parent balancers/resolvers. So the eds balancer may
-	// see the same error multiple times. We way want to only forward the error
-	// to eds if it's not a connection error.
-	//
 	// This is not necessary today, because xds client never sends connection
 	// errors.
 	if fromParent && xdsclient.ErrType(err) == xdsclient.ErrorTypeResourceNotFound {
-		b.cancelWatch()
+		b.clusterHandler.close()
 	}
-	if b.edsLB != nil {
-		b.edsLB.ResolverError(err)
+	if b.childLB != nil {
+		if xdsclient.ErrType(err) != xdsclient.ErrorTypeConnection {
+			// Connection errors will be sent to the child balancers directly.
+			// There's no need to forward them.
+			b.childLB.ResolverError(err)
+		}
 	} else {
-		// If eds balancer was never created, fail the RPCs with
+		// If the child balancer was never created, fail the RPCs with
 		// errors.
 		b.ccw.UpdateState(balancer.State{
 			ConnectivityState: connectivity.TransientFailure,
@@ -449,16 +420,6 @@
 	}
 }
-// handleClusterUpdate is the CDS watch API callback. It simply pushes the
-// received information on to the update channel for run() to pick it up.
-func (b *cdsBalancer) handleClusterUpdate(cu xdsclient.ClusterUpdate, err error) {
-	if b.closed.HasFired() {
-		b.logger.Warningf("xds: received cluster update {%+v} after cdsBalancer was closed", cu)
-		return
-	}
-	b.updateCh.Put(&watchUpdate{cds: cu, err: err})
-}
-
 // UpdateClientConnState receives the serviceConfig (which contains the
 // clusterName to watch for in CDS) and the xdsClient object from the
 // xdsResolver.
@@ -468,6 +429,14 @@ func (b *cdsBalancer) UpdateClientConnState(state balancer.ClientConnState) erro
 		return errBalancerClosed
 	}
+	if b.xdsClient == nil {
+		c := xdsclient.FromResolverState(state.ResolverState)
+		if c == nil {
+			return balancer.ErrBadResolverState
+		}
+		b.xdsClient = c
+	}
+
 	b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(state.BalancerConfig))
 	// The errors checked here should ideally never happen because the
 	// ServiceConfig in this case is prepared by the xdsResolver and is not
diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go
index 52b1a05f1362..7eb1d0889395 100644
--- a/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go
+++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_security_test.go
@@ -36,10 +36,10 @@ import (
 	"google.golang.org/grpc/internal/testutils"
 	"google.golang.org/grpc/internal/xds/matcher"
 	"google.golang.org/grpc/resolver"
-	xdsclient "google.golang.org/grpc/xds/internal/client"
-	"google.golang.org/grpc/xds/internal/client/bootstrap"
 	xdstestutils "google.golang.org/grpc/xds/internal/testutils"
 	"google.golang.org/grpc/xds/internal/testutils/fakeclient"
+	"google.golang.org/grpc/xds/internal/xdsclient"
+	"google.golang.org/grpc/xds/internal/xdsclient/bootstrap"
 )
 const (
@@ -133,11 +133,7 @@ func (p *fakeProvider) Close() {
 // xDSCredentials.
func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - builder := balancer.Get(cdsName) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", cdsName) @@ -157,14 +153,14 @@ func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDS // Override the creation of the EDS balancer to return a fake EDS balancer // implementation. edsB := newTestEDSBalancer() - oldEDSBalancerBuilder := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + oldEDSBalancerBuilder := newChildBalancer + newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { edsB.parentCC = cc return edsB, nil } // Push a ClientConnState update to the CDS balancer with a cluster name. - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } @@ -181,8 +177,8 @@ func setupWithXDSCreds(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDS } return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { - newXDSClient = oldNewXDSClient - newEDSBalancer = oldEDSBalancerBuilder + newChildBalancer = oldEDSBalancerBuilder + xdsC.Close() } } @@ -255,7 +251,7 @@ func (s) TestSecurityConfigWithoutXDSCreds(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} wantCCS := edsCCS(serviceName, nil, false) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) @@ -310,7 +306,7 @@ func (s) TestNoSecurityConfigWithXDSCreds(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. No security config is + // newChildBalancer function as part of test setup. No security config is // passed to the CDS balancer as part of this update. cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} wantCCS := edsCCS(serviceName, nil, false) @@ -468,7 +464,7 @@ func (s) TestSecurityConfigUpdate_BadToGood(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. wantCCS := edsCCS(serviceName, nil, false) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdateWithGoodSecurityCfg, nil}, wantCCS, edsB); err != nil { t.Fatal(err) @@ -502,7 +498,7 @@ func (s) TestGoodSecurityConfig(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. 
The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. wantCCS := edsCCS(serviceName, nil, false) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() @@ -555,7 +551,7 @@ func (s) TestSecurityConfigUpdate_GoodToFallback(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. wantCCS := edsCCS(serviceName, nil, false) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() @@ -605,7 +601,7 @@ func (s) TestSecurityConfigUpdate_GoodToBad(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. wantCCS := edsCCS(serviceName, nil, false) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() @@ -676,7 +672,7 @@ func (s) TestSecurityConfigUpdate_GoodToGood(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. 
cdsUpdate := xdsclient.ClusterUpdate{ ClusterName: serviceName, SecurityCfg: &xdsclient.SecurityConfig{ diff --git a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go index fba3ee1531e9..864af36857bc 100644 --- a/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go +++ b/xds/internal/balancer/cdsbalancer/cdsbalancer_test.go @@ -28,7 +28,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/attributes" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal" @@ -36,11 +35,10 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/edsbalancer" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/balancer/clusterresolver" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -130,7 +128,10 @@ func (tb *testEDSBalancer) waitForClientConnUpdate(ctx context.Context, wantCCS return err } gotCCS := ccs.(balancer.ClientConnState) - if !cmp.Equal(gotCCS, wantCCS, cmpopts.IgnoreUnexported(attributes.Attributes{})) { + if xdsclient.FromResolverState(gotCCS.ResolverState) == nil { + return fmt.Errorf("want resolver state with XDSClient attached, got one without") + } + if !cmp.Equal(gotCCS, wantCCS, cmpopts.IgnoreFields(resolver.State{}, "Attributes")) { return fmt.Errorf("received ClientConnState: %+v, want %+v", gotCCS, wantCCS) } return nil @@ -174,7 +175,7 @@ func (tb *testEDSBalancer) waitForClose(ctx context.Context) error { // cdsCCS is a helper function to construct a good update passed from the // xdsResolver to the cdsBalancer. -func cdsCCS(cluster string) balancer.ClientConnState { +func cdsCCS(cluster string, xdsC xdsclient.XDSClient) balancer.ClientConnState { const cdsLBConfig = `{ "loadBalancingConfig":[ { @@ -186,9 +187,9 @@ func cdsCCS(cluster string) balancer.ClientConnState { }` jsonSC := fmt.Sprintf(cdsLBConfig, cluster) return balancer.ClientConnState{ - ResolverState: resolver.State{ + ResolverState: xdsclient.SetClient(resolver.State{ ServiceConfig: internal.ParseServiceConfigForTesting.(func(string) *serviceconfig.ParseResult)(jsonSC), - }, + }, xdsC), BalancerConfig: &lbConfig{ClusterName: clusterName}, } } @@ -196,27 +197,29 @@ func cdsCCS(cluster string) balancer.ClientConnState { // edsCCS is a helper function to construct a good update passed from the // cdsBalancer to the edsBalancer. func edsCCS(service string, countMax *uint32, enableLRS bool) balancer.ClientConnState { - lbCfg := &edsbalancer.EDSConfig{ - ClusterName: service, + discoveryMechanism := clusterresolver.DiscoveryMechanism{ + Type: clusterresolver.DiscoveryMechanismTypeEDS, + Cluster: service, MaxConcurrentRequests: countMax, } if enableLRS { - lbCfg.LrsLoadReportingServerName = new(string) + discoveryMechanism.LoadReportingServerName = new(string) + } + lbCfg := &clusterresolver.LBConfig{ + DiscoveryMechanisms: []clusterresolver.DiscoveryMechanism{discoveryMechanism}, + } + return balancer.ClientConnState{ BalancerConfig: lbCfg, } } // setup creates a cdsBalancer and an edsBalancer (and overrides the -// newEDSBalancer function to return it), and also returns a cleanup function. 
+// newChildBalancer function to return it), and also returns a cleanup function. func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *xdstestutils.TestClientConn, func()) { t.Helper() - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - builder := balancer.Get(cdsName) if builder == nil { t.Fatalf("balancer.Get(%q) returned nil", cdsName) @@ -225,15 +228,15 @@ func setup(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBalancer, *x cdsB := builder.Build(tcc, balancer.BuildOptions{}) edsB := newTestEDSBalancer() - oldEDSBalancerBuilder := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { + oldEDSBalancerBuilder := newChildBalancer + newChildBalancer = func(cc balancer.ClientConn, opts balancer.BuildOptions) (balancer.Balancer, error) { edsB.parentCC = cc return edsB, nil } return xdsC, cdsB.(*cdsBalancer), edsB, tcc, func() { - newEDSBalancer = oldEDSBalancerBuilder - newXDSClient = oldNewXDSClient + newChildBalancer = oldEDSBalancerBuilder + xdsC.Close() } } @@ -243,7 +246,7 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal t.Helper() xdsC, cdsB, edsB, tcc, cancel := setup(t) - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } @@ -263,6 +266,9 @@ func setupWithWatch(t *testing.T) (*fakeclient.Client, *cdsBalancer, *testEDSBal // cdsBalancer with different inputs and verifies that the CDS watch API on the // provided xdsClient is invoked appropriately. func (s) TestUpdateClientConnState(t *testing.T) { + xdsC := fakeclient.NewClient() + defer xdsC.Close() + tests := []struct { name string ccs balancer.ClientConnState @@ -281,14 +287,14 @@ func (s) TestUpdateClientConnState(t *testing.T) { }, { name: "happy-good-case", - ccs: cdsCCS(clusterName), + ccs: cdsCCS(clusterName, xdsC), wantCluster: clusterName, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - xdsC, cdsB, _, _, cancel := setup(t) + _, cdsB, _, _, cancel := setup(t) defer func() { cancel() cdsB.Close() @@ -325,7 +331,7 @@ func (s) TestUpdateClientConnStateWithSameState(t *testing.T) { }() // This is the same clientConn update sent in setupWithWatch(). - if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != nil { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != nil { t.Fatalf("cdsBalancer.UpdateClientConnState failed with error: %v", err) } // The above update should not result in a new watch being registered. @@ -426,7 +432,7 @@ func (s) TestHandleClusterUpdateError(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} wantCCS := edsCCS(serviceName, nil, false) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -511,7 +517,7 @@ func (s) TestResolverError(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. 
The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} wantCCS := edsCCS(serviceName, nil, false) if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -560,7 +566,7 @@ func (s) TestUpdateSubConnState(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} wantCCS := edsCCS(serviceName, nil, false) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) @@ -596,8 +602,8 @@ func (s) TestCircuitBreaking(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will update // the service's counter with the new max requests. var maxRequests uint32 = 1 - cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName, MaxRequests: &maxRequests} - wantCCS := edsCCS(serviceName, &maxRequests, false) + cdsUpdate := xdsclient.ClusterUpdate{ClusterName: clusterName, MaxRequests: &maxRequests} + wantCCS := edsCCS(clusterName, &maxRequests, false) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer ctxCancel() if err := invokeWatchCbAndWait(ctx, xdsC, cdsWatchInfo{cdsUpdate, nil}, wantCCS, edsB); err != nil { @@ -606,7 +612,7 @@ func (s) TestCircuitBreaking(t *testing.T) { // Since the counter's max requests was set to 1, the first request should // succeed and the second should fail. - counter := client.GetServiceRequestsCounter(serviceName) + counter := xdsclient.GetClusterRequestsCounter(clusterName, "") if err := counter.StartRequest(maxRequests); err != nil { t.Fatal(err) } @@ -628,7 +634,7 @@ func (s) TestClose(t *testing.T) { // will trigger the watch handler on the CDS balancer, which will attempt to // create a new EDS balancer. The fake EDS balancer created above will be // returned to the CDS balancer, because we have overridden the - // newEDSBalancer function as part of test setup. + // newChildBalancer function as part of test setup. cdsUpdate := xdsclient.ClusterUpdate{ClusterName: serviceName} wantCCS := edsCCS(serviceName, nil, false) ctx, ctxCancel := context.WithTimeout(context.Background(), defaultTestTimeout) @@ -661,7 +667,7 @@ func (s) TestClose(t *testing.T) { // Make sure that the UpdateClientConnState() method on the CDS balancer // returns error. 
- if err := cdsB.UpdateClientConnState(cdsCCS(clusterName)); err != errBalancerClosed { + if err := cdsB.UpdateClientConnState(cdsCCS(clusterName, xdsC)); err != errBalancerClosed { t.Fatalf("UpdateClientConnState() after close returned %v, want %v", err, errBalancerClosed) } diff --git a/xds/internal/balancer/cdsbalancer/cluster_handler.go b/xds/internal/balancer/cdsbalancer/cluster_handler.go index 4c13b55594da..1f5acafe110b 100644 --- a/xds/internal/balancer/cdsbalancer/cluster_handler.go +++ b/xds/internal/balancer/cdsbalancer/cluster_handler.go @@ -20,15 +20,29 @@ import ( "errors" "sync" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/xdsclient" ) var errNotReceivedUpdate = errors.New("tried to construct a cluster update on a cluster that has not received an update") +// clusterHandlerUpdate wraps the information received from the registered CDS +// watcher. A non-nil error is propagated to the underlying cluster_resolver +// balancer. A valid update results in creating a new cluster_resolver balancer +// (if one doesn't already exist) and pushing the update to it. +type clusterHandlerUpdate struct { + // securityCfg is the Security Config from the top (root) cluster. + securityCfg *xdsclient.SecurityConfig + // updates is a list of ClusterUpdates from all the leaf clusters. + updates []xdsclient.ClusterUpdate + err error +} + // clusterHandler will be given a name representing a cluster. It will then // update the CDS policy constantly with a list of Clusters to pass down to // XdsClusterResolverLoadBalancingPolicyConfig in a stream like fashion. type clusterHandler struct { + parent *cdsBalancer + // A mutex to protect entire tree of clusters. clusterMutex sync.Mutex root *clusterNode @@ -39,8 +53,13 @@ type clusterHandler struct { // update or from a child with an error. Capacity of one as the only update // CDS Balancer cares about is the most recent update. updateChannel chan clusterHandlerUpdate +} - xdsClient xdsClientInterface +func newClusterHandler(parent *cdsBalancer) *clusterHandler { + return &clusterHandler{ + parent: parent, + updateChannel: make(chan clusterHandlerUpdate, 1), + } } func (ch *clusterHandler) updateRootCluster(rootClusterName string) { @@ -48,7 +67,7 @@ func (ch *clusterHandler) updateRootCluster(rootClusterName string) { defer ch.clusterMutex.Unlock() if ch.root == nil { // Construct a root node on first update. - ch.root = createClusterNode(rootClusterName, ch.xdsClient, ch) + ch.root = createClusterNode(rootClusterName, ch.parent.xdsClient, ch) ch.rootClusterName = rootClusterName return } @@ -56,24 +75,33 @@ func (ch *clusterHandler) updateRootCluster(rootClusterName string) { // new one, if not do nothing. if rootClusterName != ch.rootClusterName { ch.root.delete() - ch.root = createClusterNode(rootClusterName, ch.xdsClient, ch) + ch.root = createClusterNode(rootClusterName, ch.parent.xdsClient, ch) ch.rootClusterName = rootClusterName } } // This function tries to construct a cluster update to send to CDS. func (ch *clusterHandler) constructClusterUpdate() { - // If there was an error received no op, as this simply means one of the - // children hasn't received an update yet. - if clusterUpdate, err := ch.root.constructClusterUpdate(); err == nil { - // For a ClusterUpdate, the only update CDS cares about is the most - // recent one, so opportunistically drain the update channel before - // sending the new update. 
-	select {
-	case <-ch.updateChannel:
-	default:
-	}
-	ch.updateChannel <- clusterHandlerUpdate{chu: clusterUpdate, err: nil}
+	if ch.root == nil {
+		// If root is nil, the handler has been closed; ignore the update.
+		return
+	}
+	clusterUpdate, err := ch.root.constructClusterUpdate()
+	if err != nil {
+		// If an error was received, do nothing, as it simply means one of
+		// the children hasn't received an update yet.
+		return
+	}
+	// For a ClusterUpdate, the only update CDS cares about is the most
+	// recent one, so opportunistically drain the update channel before
+	// sending the new update.
+	select {
+	case <-ch.updateChannel:
+	default:
+	}
+	ch.updateChannel <- clusterHandlerUpdate{
+		securityCfg: ch.root.clusterUpdate.SecurityCfg,
+		updates:     clusterUpdate,
+	}
 }
@@ -82,6 +110,9 @@ func (ch *clusterHandler) constructClusterUpdate() {
 func (ch *clusterHandler) close() {
 	ch.clusterMutex.Lock()
 	defer ch.clusterMutex.Unlock()
+	if ch.root == nil {
+		return
+	}
 	ch.root.delete()
 	ch.root = nil
 	ch.rootClusterName = ""
@@ -112,12 +143,17 @@ type clusterNode struct {
 // CreateClusterNode creates a cluster node from a given clusterName. This will
 // also start the watch for that cluster.
-func createClusterNode(clusterName string, xdsClient xdsClientInterface, topLevelHandler *clusterHandler) *clusterNode {
+func createClusterNode(clusterName string, xdsClient xdsclient.XDSClient, topLevelHandler *clusterHandler) *clusterNode {
 	c := &clusterNode{
 		clusterHandler: topLevelHandler,
 	}
 	// Communicate with the xds client here.
-	c.cancelFunc = xdsClient.WatchCluster(clusterName, c.handleResp)
+	topLevelHandler.parent.logger.Infof("CDS watch started on %v", clusterName)
+	cancel := xdsClient.WatchCluster(clusterName, c.handleResp)
+	c.cancelFunc = func() {
+		topLevelHandler.parent.logger.Infof("CDS watch canceled on %v", clusterName)
+		cancel()
+	}
 	return c
 }
@@ -172,15 +208,10 @@ func (c *clusterNode) handleResp(clusterUpdate xdsclient.ClusterUpdate, err erro
 		case <-c.clusterHandler.updateChannel:
 		default:
 		}
-		c.clusterHandler.updateChannel <- clusterHandlerUpdate{chu: nil, err: err}
+		c.clusterHandler.updateChannel <- clusterHandlerUpdate{err: err}
 		return
 	}
-	// deltaInClusterUpdateFields determines whether there was a delta in the
-	// clusterUpdate fields (forgetting the children). This will be used to help
-	// determine whether to pingClusterHandler at the end of this callback or
-	// not.
-	deltaInClusterUpdateFields := clusterUpdate.ClusterName != c.clusterUpdate.ClusterName || clusterUpdate.ClusterType != c.clusterUpdate.ClusterType
 	c.receivedUpdate = true
 	c.clusterUpdate = clusterUpdate
@@ -195,9 +226,14 @@ func (c *clusterNode) handleResp(clusterUpdate xdsclient.ClusterUpdate, err erro
 			child.delete()
 		}
 		c.children = nil
-		if deltaInClusterUpdateFields {
-			c.clusterHandler.constructClusterUpdate()
-		}
+		// This is an update from the only leaf node; try to send an update
+		// to the parent CDS balancer.
+		//
+		// Note that this update might be a duplicate of the previous one: the
+		// update contains not only the cluster name to watch, but also extra
+		// fields (e.g. security config), and there's no good way to compare
+		// all the fields.
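// The updateChannel above has capacity one and is drained before each send,
// so a receiver only ever observes the most recent update. A standalone
// sketch of this coalescing idiom, assuming a single (or mutex-serialized)
// sender and a hypothetical chan of int:
//
//	updates := make(chan int, 1)
//	send := func(v int) {
//		select {
//		case <-updates: // Discard a stale, unconsumed value, if any.
//		default:
//		}
//		updates <- v // Cannot block: the channel is empty at this point.
//	}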
+		c.clusterHandler.constructClusterUpdate()
 		return
 	}
@@ -230,7 +266,7 @@ func (c *clusterNode) handleResp(clusterUpdate xdsclient.ClusterUpdate, err erro
 	for child := range newChildren {
 		if _, inChildrenAlready := mapCurrentChildren[child]; !inChildrenAlready {
 			createdChild = true
-			mapCurrentChildren[child] = createClusterNode(child, c.clusterHandler.xdsClient, c.clusterHandler)
+			mapCurrentChildren[child] = createClusterNode(child, c.clusterHandler.parent.xdsClient, c.clusterHandler)
 		}
 	}
diff --git a/xds/internal/balancer/cdsbalancer/cluster_handler_test.go b/xds/internal/balancer/cdsbalancer/cluster_handler_test.go
index 049cdf55ee6a..dc69dd34e2af 100644
--- a/xds/internal/balancer/cdsbalancer/cluster_handler_test.go
+++ b/xds/internal/balancer/cdsbalancer/cluster_handler_test.go
@@ -24,8 +24,8 @@ import (
 	"testing"
 	"github.com/google/go-cmp/cmp"
-	xdsclient "google.golang.org/grpc/xds/internal/client"
 	"google.golang.org/grpc/xds/internal/testutils/fakeclient"
+	"google.golang.org/grpc/xds/internal/xdsclient"
 )
 const (
@@ -40,13 +40,7 @@ const (
 // xds client.
 func setupTests(t *testing.T) (*clusterHandler, *fakeclient.Client) {
 	xdsC := fakeclient.NewClient()
-	ch := &clusterHandler{
-		xdsClient: xdsC,
-		// This is will be how the update channel is created in cds. It will be
-		// a separate channel to the buffer.Unbounded. This channel will also be
-		// read from to test any cluster updates.
-		updateChannel: make(chan clusterHandlerUpdate, 1),
-	}
+	ch := newClusterHandler(&cdsBalancer{xdsClient: xdsC})
 	return ch, xdsC
 }
@@ -101,7 +95,7 @@ func (s) TestSuccessCaseLeafNode(t *testing.T) {
 			fakeClient.InvokeWatchClusterCallback(test.clusterUpdate, nil)
 			select {
 			case chu := <-ch.updateChannel:
-				if diff := cmp.Diff(chu.chu, []xdsclient.ClusterUpdate{test.clusterUpdate}); diff != "" {
+				if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{test.clusterUpdate}); diff != "" {
 					t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff)
 				}
 			case <-ctx.Done():
@@ -176,15 +170,15 @@ func (s) TestSuccessCaseLeafNodeThenNewUpdate(t *testing.T) {
 				t.Fatal("Timed out waiting for update from updateChannel.")
 			}
-			// Check that sending the same cluster update does not induce a
-			// update to be written to update buffer.
+			// Check that sending the same cluster update also induces an update
+			// to be written to the update buffer.
 			fakeClient.InvokeWatchClusterCallback(test.clusterUpdate, nil)
 			shouldNotHappenCtx, shouldNotHappenCtxCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
 			defer shouldNotHappenCtxCancel()
 			select {
 			case <-ch.updateChannel:
-				t.Fatal("Should not have written an update to update buffer, as cluster update did not change.")
 			case <-shouldNotHappenCtx.Done():
+				t.Fatal("Timed out waiting for update from updateChannel.")
 			}
 			// Above represents same thing as the simple
@@ -195,7 +189,7 @@
 			fakeClient.InvokeWatchClusterCallback(test.newClusterUpdate, nil)
 			select {
 			case chu := <-ch.updateChannel:
-				if diff := cmp.Diff(chu.chu, []xdsclient.ClusterUpdate{test.newClusterUpdate}); diff != "" {
+				if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{test.newClusterUpdate}); diff != "" {
 					t.Fatalf("got unexpected cluster update, diff (-got, +want): %v", diff)
 				}
 			case <-ctx.Done():
@@ -311,7 +305,7 @@ func (s) TestUpdateRootClusterAggregateSuccess(t *testing.T) {
 	// ordered as per the cluster update.
select { case chu := <-ch.updateChannel: - if diff := cmp.Diff(chu.chu, []xdsclient.ClusterUpdate{{ + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ ClusterType: xdsclient.ClusterTypeEDS, ClusterName: edsService, }, { @@ -418,7 +412,7 @@ func (s) TestUpdateRootClusterAggregateThenChangeChild(t *testing.T) { select { case chu := <-ch.updateChannel: - if diff := cmp.Diff(chu.chu, []xdsclient.ClusterUpdate{{ + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ ClusterType: xdsclient.ClusterTypeEDS, ClusterName: edsService, }, { @@ -664,7 +658,7 @@ func (s) TestSwitchClusterNodeBetweenLeafAndAggregated(t *testing.T) { // Then an update should successfully be written to the update buffer. select { case chu := <-ch.updateChannel: - if diff := cmp.Diff(chu.chu, []xdsclient.ClusterUpdate{{ + if diff := cmp.Diff(chu.updates, []xdsclient.ClusterUpdate{{ ClusterType: xdsclient.ClusterTypeEDS, ClusterName: edsService2, }}); diff != "" { diff --git a/xds/internal/balancer/clusterimpl/balancer_test.go b/xds/internal/balancer/clusterimpl/balancer_test.go index cfce6f673913..e583b647473d 100644 --- a/xds/internal/balancer/clusterimpl/balancer_test.go +++ b/xds/internal/balancer/clusterimpl/balancer_test.go @@ -22,6 +22,7 @@ package clusterimpl import ( "context" + "fmt" "strings" "testing" "time" @@ -34,17 +35,20 @@ import ( "google.golang.org/grpc/internal" internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" "google.golang.org/grpc/resolver" - "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" + xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) const ( - defaultTestTimeout = 1 * time.Second - testClusterName = "test-cluster" - testServiceName = "test-eds-service" - testLRSServerName = "test-lrs-name" + defaultTestTimeout = 1 * time.Second + defaultShortTestTimeout = 100 * time.Microsecond + + testClusterName = "test-cluster" + testServiceName = "test-eds-service" + testLRSServerName = "test-lrs-name" ) var ( @@ -72,11 +76,9 @@ func init() { // TestDropByCategory verifies that the balancer correctly drops the picks, and // that the drops are reported. 
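// The drop tests below drive the picker by hand; a minimal sketch of the
// pick/Done lifecycle that produces the load reports asserted on later,
// assuming a balancer.Picker p taken from the test ClientConn:
//
//	res, err := p.Pick(balancer.PickInfo{})
//	if err == nil && res.Done != nil {
//		// Done reports the RPC outcome so the policy can record per-locality
//		// succeeded/errored counts; a dropped pick surfaces as a non-nil err.
//		res.Done(balancer.DoneInfo{})
//	}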
func TestDropByCategory(t *testing.T) { - defer client.ClearCounterForTesting(testClusterName) + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -89,9 +91,7 @@ func TestDropByCategory(t *testing.T) { dropDenominator = 2 ) if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -162,6 +162,9 @@ func TestDropByCategory(t *testing.T) { Service: testServiceName, TotalDrops: dropCount, Drops: map[string]uint64{dropReason: dropCount}, + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: rpcCount - dropCount}}, + }, }} gotStatsData0 := loadStore.Stats([]string{testClusterName}) @@ -176,9 +179,7 @@ func TestDropByCategory(t *testing.T) { dropDenominator2 = 4 ) if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -219,6 +220,9 @@ func TestDropByCategory(t *testing.T) { Service: testServiceName, TotalDrops: dropCount2, Drops: map[string]uint64{dropReason2: dropCount2}, + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: rpcCount - dropCount2}}, + }, }} gotStatsData1 := loadStore.Stats([]string{testClusterName}) @@ -230,11 +234,9 @@ func TestDropByCategory(t *testing.T) { // TestDropCircuitBreaking verifies that the balancer correctly drops the picks // due to circuit breaking, and that the drops are reported. func TestDropCircuitBreaking(t *testing.T) { - defer client.ClearCounterForTesting(testClusterName) + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -243,9 +245,7 @@ func TestDropCircuitBreaking(t *testing.T) { var maxRequest uint32 = 50 if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -330,6 +330,9 @@ func TestDropCircuitBreaking(t *testing.T) { Cluster: testClusterName, Service: testServiceName, TotalDrops: uint64(maxRequest), + LocalityStats: map[string]load.LocalityData{ + assertString(xdsinternal.LocalityID{}.ToString): {RequestStats: load.RequestData{Succeeded: uint64(rpcCount - maxRequest + 50)}}, + }, }} gotStatsData0 := loadStore.Stats([]string{testClusterName}) @@ -342,11 +345,9 @@ func TestDropCircuitBreaking(t *testing.T) { // picker after it's closed. 
Because picker updates are sent in the run() // goroutine. func TestPickerUpdateAfterClose(t *testing.T) { - defer client.ClearCounterForTesting(testClusterName) + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -354,9 +355,7 @@ func TestPickerUpdateAfterClose(t *testing.T) { var maxRequest uint32 = 50 if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -387,11 +386,9 @@ func TestPickerUpdateAfterClose(t *testing.T) { // TestClusterNameInAddressAttributes covers the case that cluster name is // attached to the subconn address attributes. func TestClusterNameInAddressAttributes(t *testing.T) { - defer client.ClearCounterForTesting(testClusterName) + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -399,9 +396,7 @@ func TestClusterNameInAddressAttributes(t *testing.T) { defer b.Close() if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -450,9 +445,7 @@ func TestClusterNameInAddressAttributes(t *testing.T) { const testClusterName2 = "test-cluster-2" var addr2 = resolver.Address{Addr: "2.2.2.2"} if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: []resolver.Address{addr2}, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: []resolver.Address{addr2}}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName2, EDSServiceName: testServiceName, @@ -478,11 +471,9 @@ func TestClusterNameInAddressAttributes(t *testing.T) { // TestReResolution verifies that when a SubConn turns transient failure, // re-resolution is triggered. 
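// A minimal sketch of the behavior TestReResolution verifies, using a
// hypothetical helper with the balancer.ClientConn cc received at Build
// time: a SubConn entering TransientFailure nudges the resolver.
//
//	func onSubConnStateChange(cc balancer.ClientConn, s balancer.SubConnState) {
//		if s.ConnectivityState == connectivity.TransientFailure {
//			cc.ResolveNow(resolver.ResolveNowOptions{}) // Observed as ResolveNow() in the test.
//		}
//	}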
func TestReResolution(t *testing.T) { - defer client.ClearCounterForTesting(testClusterName) + defer xdsclient.ClearCounterForTesting(testClusterName, testServiceName) xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() + defer xdsC.Close() builder := balancer.Get(Name) cc := testutils.NewTestClientConn(t) @@ -490,9 +481,7 @@ func TestReResolution(t *testing.T) { defer b.Close() if err := b.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, + ResolverState: xdsclient.SetClient(resolver.State{Addresses: testBackendAddrs}, xdsC), BalancerConfig: &LBConfig{ Cluster: testClusterName, EDSServiceName: testServiceName, @@ -557,3 +546,220 @@ func TestReResolution(t *testing.T) { t.Fatalf("timeout waiting for ResolveNow()") } } + +func TestLoadReporting(t *testing.T) { + var testLocality = xdsinternal.LocalityID{ + Region: "test-region", + Zone: "test-zone", + SubZone: "test-sub-zone", + } + + xdsC := fakeclient.NewClient() + defer xdsC.Close() + + builder := balancer.Get(Name) + cc := testutils.NewTestClientConn(t) + b := builder.Build(cc, balancer.BuildOptions{}) + defer b.Close() + + addrs := make([]resolver.Address, len(testBackendAddrs)) + for i, a := range testBackendAddrs { + addrs[i] = xdsinternal.SetLocalityID(a, testLocality) + } + if err := b.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC), + BalancerConfig: &LBConfig{ + Cluster: testClusterName, + EDSServiceName: testServiceName, + LoadReportingServerName: newString(testLRSServerName), + // Locality: testLocality, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatalf("unexpected error from UpdateClientConnState: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + got, err := xdsC.WaitForReportLoad(ctx) + if err != nil { + t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) + } + if got.Server != testLRSServerName { + t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerName) + } + + sc1 := <-cc.NewSubConnCh + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + // This should get the connecting picker. + p0 := <-cc.NewPickerCh + for i := 0; i < 10; i++ { + _, err := p0.Pick(balancer.PickInfo{}) + if err != balancer.ErrNoSubConnAvailable { + t.Fatalf("picker.Pick, got _,%v, want Err=%v", err, balancer.ErrNoSubConnAvailable) + } + } + + b.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + // Test pick with one backend. 
+	p1 := <-cc.NewPickerCh
+	const successCount = 5
+	for i := 0; i < successCount; i++ {
+		gotSCSt, err := p1.Pick(balancer.PickInfo{})
+		if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) {
+			t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1)
+		}
+		gotSCSt.Done(balancer.DoneInfo{})
+	}
+	const errorCount = 5
+	for i := 0; i < errorCount; i++ {
+		gotSCSt, err := p1.Pick(balancer.PickInfo{})
+		if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) {
+			t.Fatalf("picker.Pick, got %v, %v, want SubConn=%v", gotSCSt, err, sc1)
+		}
+		gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("error")})
+	}
+
+	// Dump load data from the store and compare with expected counts.
+	loadStore := xdsC.LoadStore()
+	if loadStore == nil {
+		t.Fatal("loadStore is nil in xdsClient")
+	}
+	sds := loadStore.Stats([]string{testClusterName})
+	if len(sds) == 0 {
+		t.Fatalf("loads for cluster %v not found in store", testClusterName)
+	}
+	sd := sds[0]
+	if sd.Cluster != testClusterName || sd.Service != testServiceName {
+		t.Fatalf("got unexpected load for %q, %q, want %q, %q", sd.Cluster, sd.Service, testClusterName, testServiceName)
+	}
+	testLocalityJSON, _ := testLocality.ToString()
+	localityData, ok := sd.LocalityStats[testLocalityJSON]
+	if !ok {
+		t.Fatalf("loads for %v not found in store", testLocality)
+	}
+	reqStats := localityData.RequestStats
+	if reqStats.Succeeded != successCount {
+		t.Errorf("got succeeded %v, want %v", reqStats.Succeeded, successCount)
+	}
+	if reqStats.Errored != errorCount {
+		t.Errorf("got errored %v, want %v", reqStats.Errored, errorCount)
+	}
+	if reqStats.InProgress != 0 {
+		t.Errorf("got inProgress %v, want %v", reqStats.InProgress, 0)
+	}
+
+	b.Close()
+	if err := xdsC.WaitForCancelReportLoad(ctx); err != nil {
+		t.Fatalf("unexpected error waiting for load report to be canceled: %v", err)
+	}
+}
+
+// TestUpdateLRSServer covers the cases
+// - the init config specifies "" as the LRS server
+// - config modifies LRS server to a different string
+// - config sets LRS server to nil to stop load reporting
+func TestUpdateLRSServer(t *testing.T) {
+	var testLocality = xdsinternal.LocalityID{
+		Region:  "test-region",
+		Zone:    "test-zone",
+		SubZone: "test-sub-zone",
+	}
+
+	xdsC := fakeclient.NewClient()
+	defer xdsC.Close()
+
+	builder := balancer.Get(Name)
+	cc := testutils.NewTestClientConn(t)
+	b := builder.Build(cc, balancer.BuildOptions{})
+	defer b.Close()
+
+	addrs := make([]resolver.Address, len(testBackendAddrs))
+	for i, a := range testBackendAddrs {
+		addrs[i] = xdsinternal.SetLocalityID(a, testLocality)
+	}
+	if err := b.UpdateClientConnState(balancer.ClientConnState{
+		ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC),
+		BalancerConfig: &LBConfig{
+			Cluster:                 testClusterName,
+			EDSServiceName:          testServiceName,
+			LoadReportingServerName: newString(""),
+			ChildPolicy: &internalserviceconfig.BalancerConfig{
+				Name: roundrobin.Name,
+			},
+		},
+	}); err != nil {
+		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+	defer cancel()
+
+	got, err := xdsC.WaitForReportLoad(ctx)
+	if err != nil {
+		t.Fatalf("xdsClient.ReportLoad failed with error: %v", err)
+	}
+	if got.Server != "" {
+		t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, "")
+	}
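TestLoadReporting above indexes the dumped stats by the JSON form of the locality ID, which is what `LocalityID.ToString` produces; the same string is what the picker hands to the load store. The lookup in isolation, with field and method names as they appear in this diff and the surrounding test variables assumed:

```go
key, err := testLocality.ToString() // JSON-encoded xdsinternal.LocalityID
if err != nil {
	t.Fatalf("failed to marshal locality ID: %v", err)
}
ld, ok := sd.LocalityStats[key] // sd is one entry from loadStore.Stats
if !ok {
	t.Fatalf("no loads recorded for locality %+v", testLocality)
}
// RequestStats tracks per-locality RPC outcomes.
_ = ld.RequestStats.Succeeded  // finished without error
_ = ld.RequestStats.Errored    // finished with an error
_ = ld.RequestStats.InProgress // started but not yet finished
```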
+
+	// Update LRS server to a different name.
+	if err := b.UpdateClientConnState(balancer.ClientConnState{
+		ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC),
+		BalancerConfig: &LBConfig{
+			Cluster:                 testClusterName,
+			EDSServiceName:          testServiceName,
+			LoadReportingServerName: newString(testLRSServerName),
+			ChildPolicy: &internalserviceconfig.BalancerConfig{
+				Name: roundrobin.Name,
+			},
+		},
+	}); err != nil {
+		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
+	}
+	if err := xdsC.WaitForCancelReportLoad(ctx); err != nil {
+		t.Fatalf("unexpected error waiting for load report to be canceled: %v", err)
+	}
+	got2, err2 := xdsC.WaitForReportLoad(ctx)
+	if err2 != nil {
+		t.Fatalf("xdsClient.ReportLoad failed with error: %v", err2)
+	}
+	if got2.Server != testLRSServerName {
+		t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got2.Server, testLRSServerName)
+	}
+
+	// Update LRS server to nil, to disable LRS.
+	if err := b.UpdateClientConnState(balancer.ClientConnState{
+		ResolverState: xdsclient.SetClient(resolver.State{Addresses: addrs}, xdsC),
+		BalancerConfig: &LBConfig{
+			Cluster:                 testClusterName,
+			EDSServiceName:          testServiceName,
+			LoadReportingServerName: nil,
+			ChildPolicy: &internalserviceconfig.BalancerConfig{
+				Name: roundrobin.Name,
+			},
+		},
+	}); err != nil {
+		t.Fatalf("unexpected error from UpdateClientConnState: %v", err)
+	}
+	if err := xdsC.WaitForCancelReportLoad(ctx); err != nil {
+		t.Fatalf("unexpected error waiting for load report to be canceled: %v", err)
+	}
+
+	shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultShortTestTimeout)
+	defer shortCancel()
+	if s, err := xdsC.WaitForReportLoad(shortCtx); err != context.DeadlineExceeded {
+		t.Fatalf("unexpected load report to server: %q", s)
+	}
+}
+
+func assertString(f func() (string, error)) string {
+	s, err := f()
+	if err != nil {
+		panic(err.Error())
+	}
+	return s
+}
diff --git a/xds/internal/balancer/clusterimpl/clusterimpl.go b/xds/internal/balancer/clusterimpl/clusterimpl.go
index 4bd29901d760..1b49dccbc633 100644
--- a/xds/internal/balancer/clusterimpl/clusterimpl.go
+++ b/xds/internal/balancer/clusterimpl/clusterimpl.go
@@ -27,6 +27,7 @@ import (
 	"encoding/json"
 	"fmt"
 	"sync"
+	"sync/atomic"
 
 	"google.golang.org/grpc/balancer"
 	"google.golang.org/grpc/connectivity"
@@ -37,9 +38,10 @@ import (
 	"google.golang.org/grpc/internal/pretty"
 	"google.golang.org/grpc/resolver"
 	"google.golang.org/grpc/serviceconfig"
+	xdsinternal "google.golang.org/grpc/xds/internal"
 	"google.golang.org/grpc/xds/internal/balancer/loadstore"
-	xdsclient "google.golang.org/grpc/xds/internal/client"
-	"google.golang.org/grpc/xds/internal/client/load"
+	"google.golang.org/grpc/xds/internal/xdsclient"
+	"google.golang.org/grpc/xds/internal/xdsclient/load"
 )
 
 const (
@@ -49,51 +51,36 @@ )
 
 func init() {
-	balancer.Register(clusterImplBB{})
+	balancer.Register(bb{})
 }
 
-var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() }
+type bb struct{}
 
-type clusterImplBB struct{}
-
-func (clusterImplBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
+func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer {
 	b := &clusterImplBalancer{
 		ClientConn:      cc,
 		bOpts:           bOpts,
 		closed:          grpcsync.NewEvent(),
+		done:            grpcsync.NewEvent(),
 		loadWrapper:     loadstore.NewWrapper(),
+		scWrappers:      make(map[balancer.SubConn]*scWrapper),
 		pickerUpdateCh:  buffer.NewUnbounded(),
 		requestCountMax: defaultRequestCountMax,
	}
 	b.logger = prefixLogger(b)
-
-	client, err :=
newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.xdsC = client go b.run() - b.logger.Infof("Created") return b } -func (clusterImplBB) Name() string { +func (bb) Name() string { return Name } -func (clusterImplBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } -// xdsClientInterface contains only the xds_client methods needed by LRS -// balancer. It's defined so we can override xdsclient in tests. -type xdsClientInterface interface { - ReportLoad(server string) (*load.Store, func()) - Close() -} - type clusterImplBalancer struct { balancer.ClientConn @@ -107,47 +94,64 @@ type clusterImplBalancer struct { // synchronized with Close(). mu sync.Mutex closed *grpcsync.Event + done *grpcsync.Event - bOpts balancer.BuildOptions - logger *grpclog.PrefixLogger - xdsC xdsClientInterface + bOpts balancer.BuildOptions + logger *grpclog.PrefixLogger + xdsClient xdsclient.XDSClient config *LBConfig childLB balancer.Balancer cancelLoadReport func() edsServiceName string - lrsServerName string + lrsServerName *string loadWrapper *loadstore.Wrapper clusterNameMu sync.Mutex clusterName string - // childState/drops/requestCounter can only be accessed in run(). And run() - // is the only goroutine that sends picker to the parent ClientConn. All + scWrappersMu sync.Mutex + // The SubConns passed to the child policy are wrapped in a wrapper, to keep + // locality ID. But when the parent ClientConn sends updates, it's going to + // give the original SubConn, not the wrapper. But the child policies only + // know about the wrapper, so when forwarding SubConn updates, they must be + // sent for the wrappers. + // + // This keeps a map from original SubConn to wrapper, so that when + // forwarding the SubConn state update, the child policy will get the + // wrappers. + scWrappers map[balancer.SubConn]*scWrapper + + // childState/drops/requestCounter keeps the state used by the most recently + // generated picker. All fields can only be accessed in run(). And run() is + // the only goroutine that sends picker to the parent ClientConn. All // requests to update picker need to be sent to pickerUpdateCh. - childState balancer.State - drops []*dropper - requestCounter *xdsclient.ServiceRequestsCounter - requestCountMax uint32 - pickerUpdateCh *buffer.Unbounded + childState balancer.State + dropCategories []DropConfig // The categories for drops. + drops []*dropper + requestCounterCluster string // The cluster name for the request counter. + requestCounterService string // The service name for the request counter. + requestCounter *xdsclient.ClusterRequestsCounter + requestCountMax uint32 + pickerUpdateCh *buffer.Unbounded } // updateLoadStore checks the config for load store, and decides whether it // needs to restart the load reporting stream. -func (cib *clusterImplBalancer) updateLoadStore(newConfig *LBConfig) error { +func (b *clusterImplBalancer) updateLoadStore(newConfig *LBConfig) error { var updateLoadClusterAndService bool // ClusterName is different, restart. ClusterName is from ClusterName and // EDSServiceName. 
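Per the struct comment above, everything feeding the picker is owned by the run() goroutine; other goroutines only ever Put onto pickerUpdateCh. A sketch of that funneling pattern using the internal buffer.Unbounded API shown in this patch (the real loop appears later in this file; the method name here is illustrative and payload handling is elided):

```go
// Producer side: any goroutine may request a picker update with
// b.pickerUpdateCh.Put(newConfig), or with a balancer.State from the child.

// Consumer side: a single goroutine owns all picker-related state.
func (b *clusterImplBalancer) pickerLoopSketch() {
	for {
		select {
		case update := <-b.pickerUpdateCh.Get():
			b.pickerUpdateCh.Load() // let the buffer release the next item
			switch update.(type) {
			case balancer.State: // child picker changed: regenerate ours
			case *LBConfig: // drop/circuit-breaking config may have changed
			}
		case <-b.closed.Done():
			return
		}
	}
}
```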
- clusterName := cib.getClusterName() + clusterName := b.getClusterName() if clusterName != newConfig.Cluster { updateLoadClusterAndService = true - cib.setClusterName(newConfig.Cluster) + b.setClusterName(newConfig.Cluster) clusterName = newConfig.Cluster } - if cib.edsServiceName != newConfig.EDSServiceName { + if b.edsServiceName != newConfig.EDSServiceName { updateLoadClusterAndService = true - cib.edsServiceName = newConfig.EDSServiceName + b.edsServiceName = newConfig.EDSServiceName } if updateLoadClusterAndService { // This updates the clusterName and serviceName that will be reported @@ -158,39 +162,65 @@ func (cib *clusterImplBalancer) updateLoadStore(newConfig *LBConfig) error { // On the other hand, this will almost never happen. Each LRS policy // shouldn't get updated config. The parent should do a graceful switch // when the clusterName or serviceName is changed. - cib.loadWrapper.UpdateClusterAndService(clusterName, cib.edsServiceName) + b.loadWrapper.UpdateClusterAndService(clusterName, b.edsServiceName) } + var ( + stopOldLoadReport bool + startNewLoadReport bool + ) + // Check if it's necessary to restart load report. - var newLRSServerName string - if newConfig.LoadReportingServerName != nil { - newLRSServerName = *newConfig.LoadReportingServerName - } - if cib.lrsServerName != newLRSServerName { - // LoadReportingServerName is different, load should be report to a - // different server, restart. - cib.lrsServerName = newLRSServerName - if cib.cancelLoadReport != nil { - cib.cancelLoadReport() - cib.cancelLoadReport = nil + if b.lrsServerName == nil { + if newConfig.LoadReportingServerName != nil { + // Old is nil, new is not nil, start new LRS. + b.lrsServerName = newConfig.LoadReportingServerName + startNewLoadReport = true } + // Old is nil, new is nil, do nothing. + } else if newConfig.LoadReportingServerName == nil { + // Old is not nil, new is nil, stop old, don't start new. + b.lrsServerName = newConfig.LoadReportingServerName + stopOldLoadReport = true + } else { + // Old is not nil, new is not nil, compare string values, if + // different, stop old and start new. + if *b.lrsServerName != *newConfig.LoadReportingServerName { + b.lrsServerName = newConfig.LoadReportingServerName + stopOldLoadReport = true + startNewLoadReport = true + } + } + + if stopOldLoadReport { + if b.cancelLoadReport != nil { + b.cancelLoadReport() + b.cancelLoadReport = nil + if !startNewLoadReport { + // If a new LRS stream will be started later, no need to update + // it to nil here. 
+ b.loadWrapper.UpdateLoadStore(nil) + } + } + } + if startNewLoadReport { var loadStore *load.Store - if cib.xdsC != nil { - loadStore, cib.cancelLoadReport = cib.xdsC.ReportLoad(cib.lrsServerName) + if b.xdsClient != nil { + loadStore, b.cancelLoadReport = b.xdsClient.ReportLoad(*b.lrsServerName) } - cib.loadWrapper.UpdateLoadStore(loadStore) + b.loadWrapper.UpdateLoadStore(loadStore) } return nil } -func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received ClientConnState {%+v} after clusterImplBalancer was closed", s) +func (b *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + if b.closed.HasFired() { + b.logger.Warningf("xds: received ClientConnState {%+v} after clusterImplBalancer was closed", s) return nil } - cib.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) newConfig, ok := s.BalancerConfig.(*LBConfig) if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) @@ -204,59 +234,32 @@ func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState return fmt.Errorf("balancer %q not registered", newConfig.ChildPolicy.Name) } + if b.xdsClient == nil { + c := xdsclient.FromResolverState(s.ResolverState) + if c == nil { + return balancer.ErrBadResolverState + } + b.xdsClient = c + } + // Update load reporting config. This needs to be done before updating the // child policy because we need the loadStore from the updated client to be // passed to the ccWrapper, so that the next picker from the child policy // will pick up the new loadStore. - if err := cib.updateLoadStore(newConfig); err != nil { + if err := b.updateLoadStore(newConfig); err != nil { return err } - // Compare new drop config. And update picker if it's changed. - var updatePicker bool - if cib.config == nil || !equalDropCategories(cib.config.DropCategories, newConfig.DropCategories) { - cib.drops = make([]*dropper, 0, len(newConfig.DropCategories)) - for _, c := range newConfig.DropCategories { - cib.drops = append(cib.drops, newDropper(c)) - } - updatePicker = true - } - - // Compare cluster name. And update picker if it's changed, because circuit - // breaking's stream counter will be different. - if cib.config == nil || cib.config.Cluster != newConfig.Cluster { - cib.requestCounter = xdsclient.GetServiceRequestsCounter(newConfig.Cluster) - updatePicker = true - } - // Compare upper bound of stream count. And update picker if it's changed. - // This is also for circuit breaking. - var newRequestCountMax uint32 = 1024 - if newConfig.MaxConcurrentRequests != nil { - newRequestCountMax = *newConfig.MaxConcurrentRequests - } - if cib.requestCountMax != newRequestCountMax { - cib.requestCountMax = newRequestCountMax - updatePicker = true - } - - if updatePicker { - cib.pickerUpdateCh.Put(&dropConfigs{ - drops: cib.drops, - requestCounter: cib.requestCounter, - requestCountMax: cib.requestCountMax, - }) - } - // If child policy is a different type, recreate the sub-balancer. 
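The nil-aware branching above is easier to see as a decision table over the old and new LoadReportingServerName, where nil means load reporting is off. A distilled sketch, not part of the patch, with the same outcomes:

```go
// decideLRS mirrors the branching in updateLoadStore: it reports whether
// the old LRS stream must be stopped and whether a new one must be
// started. A nil server name means load reporting is disabled.
func decideLRS(oldName, newName *string) (stop, start bool) {
	switch {
	case oldName == nil && newName == nil:
		return false, false // stays disabled
	case oldName == nil:
		return false, true // newly enabled
	case newName == nil:
		return true, false // disabled
	default:
		changed := *oldName != *newName
		return changed, changed // restart only on a real change
	}
}
```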
- if cib.config == nil || cib.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { - if cib.childLB != nil { - cib.childLB.Close() + if b.config == nil || b.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { + if b.childLB != nil { + b.childLB.Close() } - cib.childLB = bb.Build(cib, cib.bOpts) + b.childLB = bb.Build(b, b.bOpts) } - cib.config = newConfig + b.config = newConfig - if cib.childLB == nil { + if b.childLB == nil { // This is not an expected situation, and should be super rare in // practice. // @@ -266,27 +269,31 @@ func (cib *clusterImplBalancer) UpdateClientConnState(s balancer.ClientConnState return fmt.Errorf("child policy is nil, this means balancer %q's Build() returned nil", newConfig.ChildPolicy.Name) } + // Notify run() of this new config, in case drop and request counter need + // update (which means a new picker needs to be generated). + b.pickerUpdateCh.Put(newConfig) + // Addresses and sub-balancer config are sent to sub-balancer. - return cib.childLB.UpdateClientConnState(balancer.ClientConnState{ + return b.childLB.UpdateClientConnState(balancer.ClientConnState{ ResolverState: s.ResolverState, - BalancerConfig: cib.config.ChildPolicy.Config, + BalancerConfig: b.config.ChildPolicy.Config, }) } -func (cib *clusterImplBalancer) ResolverError(err error) { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received resolver error {%+v} after clusterImplBalancer was closed", err) +func (b *clusterImplBalancer) ResolverError(err error) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received resolver error {%+v} after clusterImplBalancer was closed", err) return } - if cib.childLB != nil { - cib.childLB.ResolverError(err) + if b.childLB != nil { + b.childLB.ResolverError(err) } } -func (cib *clusterImplBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { - if cib.closed.HasFired() { - cib.logger.Warningf("xds: received subconn state change {%+v, %+v} after clusterImplBalancer was closed", sc, s) +func (b *clusterImplBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received subconn state change {%+v, %+v} after clusterImplBalancer was closed", sc, s) return } @@ -298,103 +305,223 @@ func (cib *clusterImplBalancer) UpdateSubConnState(sc balancer.SubConn, s balanc // knows). The parent priority policy is configured to ignore re-resolution // signal from the EDS children. if s.ConnectivityState == connectivity.TransientFailure { - cib.ClientConn.ResolveNow(resolver.ResolveNowOptions{}) + b.ClientConn.ResolveNow(resolver.ResolveNowOptions{}) } - if cib.childLB != nil { - cib.childLB.UpdateSubConnState(sc, s) + b.scWrappersMu.Lock() + if scw, ok := b.scWrappers[sc]; ok { + sc = scw + if s.ConnectivityState == connectivity.Shutdown { + // Remove this SubConn from the map on Shutdown. + delete(b.scWrappers, scw.SubConn) + } + } + b.scWrappersMu.Unlock() + if b.childLB != nil { + b.childLB.UpdateSubConnState(sc, s) } } -func (cib *clusterImplBalancer) Close() { - cib.mu.Lock() - cib.closed.Fire() - cib.mu.Unlock() +func (b *clusterImplBalancer) Close() { + b.mu.Lock() + b.closed.Fire() + b.mu.Unlock() - if cib.childLB != nil { - cib.childLB.Close() - cib.childLB = nil + if b.childLB != nil { + b.childLB.Close() + b.childLB = nil } - cib.xdsC.Close() - cib.logger.Infof("Shutdown") + <-b.done.Done() + b.logger.Infof("Shutdown") } // Override methods to accept updates from the child LB. 
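These overrides work because clusterImplBalancer embeds balancer.ClientConn: only the methods it redefines are intercepted, and everything else falls through to the parent. A generic sketch of the embedding idiom, with an illustrative type name:

```go
// interceptingCC intercepts only the methods it redefines; all other
// balancer.ClientConn methods fall through to the embedded parent.
type interceptingCC struct {
	balancer.ClientConn
}

// UpdateState is intercepted so the picker can be rebuilt (or, as in this
// file, rerouted through run()) before the update reaches the parent.
func (c *interceptingCC) UpdateState(s balancer.State) {
	c.ClientConn.UpdateState(s)
}
```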
-func (cib *clusterImplBalancer) UpdateState(state balancer.State) { +func (b *clusterImplBalancer) UpdateState(state balancer.State) { // Instead of updating parent ClientConn inline, send state to run(). - cib.pickerUpdateCh.Put(state) + b.pickerUpdateCh.Put(state) +} + +func (b *clusterImplBalancer) setClusterName(n string) { + b.clusterNameMu.Lock() + defer b.clusterNameMu.Unlock() + b.clusterName = n +} + +func (b *clusterImplBalancer) getClusterName() string { + b.clusterNameMu.Lock() + defer b.clusterNameMu.Unlock() + return b.clusterName +} + +// scWrapper is a wrapper of SubConn with locality ID. The locality ID can be +// retrieved from the addresses when creating SubConn. +// +// All SubConns passed to the child policies are wrapped in this, so that the +// picker can get the localityID from the picked SubConn, and do load reporting. +// +// After wrapping, all SubConns to and from the parent ClientConn (e.g. for +// SubConn state update, update/remove SubConn) must be the original SubConns. +// All SubConns to and from the child policy (NewSubConn, forwarding SubConn +// state update) must be the wrapper. The balancer keeps a map from the original +// SubConn to the wrapper for this purpose. +type scWrapper struct { + balancer.SubConn + // locality needs to be atomic because it can be updated while being read by + // the picker. + locality atomic.Value // type xdsinternal.LocalityID } -func (cib *clusterImplBalancer) setClusterName(n string) { - cib.clusterNameMu.Lock() - defer cib.clusterNameMu.Unlock() - cib.clusterName = n +func (scw *scWrapper) updateLocalityID(lID xdsinternal.LocalityID) { + scw.locality.Store(lID) } -func (cib *clusterImplBalancer) getClusterName() string { - cib.clusterNameMu.Lock() - defer cib.clusterNameMu.Unlock() - return cib.clusterName +func (scw *scWrapper) localityID() xdsinternal.LocalityID { + lID, _ := scw.locality.Load().(xdsinternal.LocalityID) + return lID } -func (cib *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { - clusterName := cib.getClusterName() +func (b *clusterImplBalancer) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { + clusterName := b.getClusterName() newAddrs := make([]resolver.Address, len(addrs)) + var lID xdsinternal.LocalityID for i, addr := range addrs { newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) + lID = xdsinternal.GetLocalityID(newAddrs[i]) } - return cib.ClientConn.NewSubConn(newAddrs, opts) + sc, err := b.ClientConn.NewSubConn(newAddrs, opts) + if err != nil { + return nil, err + } + // Wrap this SubConn in a wrapper, and add it to the map. + b.scWrappersMu.Lock() + ret := &scWrapper{SubConn: sc} + ret.updateLocalityID(lID) + b.scWrappers[sc] = ret + b.scWrappersMu.Unlock() + return ret, nil } -func (cib *clusterImplBalancer) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { - clusterName := cib.getClusterName() +func (b *clusterImplBalancer) RemoveSubConn(sc balancer.SubConn) { + scw, ok := sc.(*scWrapper) + if !ok { + b.ClientConn.RemoveSubConn(sc) + return + } + // Remove the original SubConn from the parent ClientConn. + // + // Note that we don't remove this SubConn from the scWrappers map. We will + // need it to forward the final SubConn state Shutdown to the child policy. + // + // This entry is kept in the map until it's state is changes to Shutdown, + // and will be deleted in UpdateSubConnState(). 
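scWrapper stores its locality in an atomic.Value because the picker reads it on every RPC while UpdateAddresses may rewrite it concurrently. The access pattern in isolation, with an illustrative holder type standing in for scWrapper:

```go
type localityHolder struct {
	locality atomic.Value // always holds an xdsinternal.LocalityID
}

func (h *localityHolder) set(l xdsinternal.LocalityID) { h.locality.Store(l) }

func (h *localityHolder) get() xdsinternal.LocalityID {
	// The comma-ok assertion yields the zero LocalityID before the first
	// Store, matching the "empty locality" fallback used by the picker.
	l, _ := h.locality.Load().(xdsinternal.LocalityID)
	return l
}
```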
+ b.ClientConn.RemoveSubConn(scw.SubConn) +} + +func (b *clusterImplBalancer) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { + clusterName := b.getClusterName() newAddrs := make([]resolver.Address, len(addrs)) + var lID xdsinternal.LocalityID for i, addr := range addrs { newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) + lID = xdsinternal.GetLocalityID(newAddrs[i]) } - cib.ClientConn.UpdateAddresses(sc, newAddrs) + if scw, ok := sc.(*scWrapper); ok { + scw.updateLocalityID(lID) + // Need to get the original SubConn from the wrapper before calling + // parent ClientConn. + sc = scw.SubConn + } + b.ClientConn.UpdateAddresses(sc, newAddrs) } type dropConfigs struct { drops []*dropper - requestCounter *xdsclient.ServiceRequestsCounter + requestCounter *xdsclient.ClusterRequestsCounter requestCountMax uint32 } -func (cib *clusterImplBalancer) run() { +// handleDropAndRequestCount compares drop and request counter in newConfig with +// the one currently used by picker. It returns a new dropConfigs if a new +// picker needs to be generated, otherwise it returns nil. +func (b *clusterImplBalancer) handleDropAndRequestCount(newConfig *LBConfig) *dropConfigs { + // Compare new drop config. And update picker if it's changed. + var updatePicker bool + if !equalDropCategories(b.dropCategories, newConfig.DropCategories) { + b.dropCategories = newConfig.DropCategories + b.drops = make([]*dropper, 0, len(newConfig.DropCategories)) + for _, c := range newConfig.DropCategories { + b.drops = append(b.drops, newDropper(c)) + } + updatePicker = true + } + + // Compare cluster name. And update picker if it's changed, because circuit + // breaking's stream counter will be different. + if b.requestCounterCluster != newConfig.Cluster || b.requestCounterService != newConfig.EDSServiceName { + b.requestCounterCluster = newConfig.Cluster + b.requestCounterService = newConfig.EDSServiceName + b.requestCounter = xdsclient.GetClusterRequestsCounter(newConfig.Cluster, newConfig.EDSServiceName) + updatePicker = true + } + // Compare upper bound of stream count. And update picker if it's changed. + // This is also for circuit breaking. 
+ var newRequestCountMax uint32 = 1024 + if newConfig.MaxConcurrentRequests != nil { + newRequestCountMax = *newConfig.MaxConcurrentRequests + } + if b.requestCountMax != newRequestCountMax { + b.requestCountMax = newRequestCountMax + updatePicker = true + } + + if !updatePicker { + return nil + } + return &dropConfigs{ + drops: b.drops, + requestCounter: b.requestCounter, + requestCountMax: b.requestCountMax, + } +} + +func (b *clusterImplBalancer) run() { + defer b.done.Fire() for { select { - case update := <-cib.pickerUpdateCh.Get(): - cib.pickerUpdateCh.Load() - cib.mu.Lock() - if cib.closed.HasFired() { - cib.mu.Unlock() + case update := <-b.pickerUpdateCh.Get(): + b.pickerUpdateCh.Load() + b.mu.Lock() + if b.closed.HasFired() { + b.mu.Unlock() return } switch u := update.(type) { case balancer.State: - cib.childState = u - cib.ClientConn.UpdateState(balancer.State{ - ConnectivityState: cib.childState.ConnectivityState, - Picker: newDropPicker(cib.childState, &dropConfigs{ - drops: cib.drops, - requestCounter: cib.requestCounter, - requestCountMax: cib.requestCountMax, - }, cib.loadWrapper), + b.childState = u + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: b.childState.ConnectivityState, + Picker: newPicker(b.childState, &dropConfigs{ + drops: b.drops, + requestCounter: b.requestCounter, + requestCountMax: b.requestCountMax, + }, b.loadWrapper), }) - case *dropConfigs: - cib.drops = u.drops - cib.requestCounter = u.requestCounter - if cib.childState.Picker != nil { - cib.ClientConn.UpdateState(balancer.State{ - ConnectivityState: cib.childState.ConnectivityState, - Picker: newDropPicker(cib.childState, u, cib.loadWrapper), + case *LBConfig: + dc := b.handleDropAndRequestCount(u) + if dc != nil && b.childState.Picker != nil { + b.ClientConn.UpdateState(balancer.State{ + ConnectivityState: b.childState.ConnectivityState, + Picker: newPicker(b.childState, dc, b.loadWrapper), }) } } - cib.mu.Unlock() - case <-cib.closed.Done(): + b.mu.Unlock() + case <-b.closed.Done(): + if b.cancelLoadReport != nil { + b.cancelLoadReport() + b.cancelLoadReport = nil + } return } } diff --git a/xds/internal/balancer/clusterimpl/picker.go b/xds/internal/balancer/clusterimpl/picker.go index 7a31615510d3..db29c550be11 100644 --- a/xds/internal/balancer/clusterimpl/picker.go +++ b/xds/internal/balancer/clusterimpl/picker.go @@ -19,13 +19,14 @@ package clusterimpl import ( + orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" "google.golang.org/grpc/balancer" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/wrr" "google.golang.org/grpc/status" - "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // NewRandomWRR is used when calculating drops. It's exported so that tests can @@ -66,21 +67,30 @@ func (d *dropper) drop() (ret bool) { return d.w.Next().(bool) } +const ( + serverLoadCPUName = "cpu_utilization" + serverLoadMemoryName = "mem_utilization" +) + // loadReporter wraps the methods from the loadStore that are used here. type loadReporter interface { + CallStarted(locality string) + CallFinished(locality string, err error) + CallServerLoad(locality, name string, val float64) CallDropped(locality string) } -type dropPicker struct { +// Picker implements RPC drop, circuit breaking drop and load reporting. 
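Because the picker depends only on the four-method loadReporter interface above, anything implementing it can stand in for the real load store. A no-op implementation, not in the patch, that a test could substitute:

```go
type nopReporter struct{}

func (nopReporter) CallStarted(locality string)                       {}
func (nopReporter) CallFinished(locality string, err error)           {}
func (nopReporter) CallServerLoad(locality, name string, val float64) {}
func (nopReporter) CallDropped(locality string)                       {}
```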
+type picker struct { drops []*dropper s balancer.State loadStore loadReporter - counter *client.ServiceRequestsCounter + counter *xdsclient.ClusterRequestsCounter countMax uint32 } -func newDropPicker(s balancer.State, config *dropConfigs, loadStore load.PerClusterReporter) *dropPicker { - return &dropPicker{ +func newPicker(s balancer.State, config *dropConfigs, loadStore load.PerClusterReporter) *picker { + return &picker{ drops: config.drops, s: s, loadStore: loadStore, @@ -89,13 +99,14 @@ func newDropPicker(s balancer.State, config *dropConfigs, loadStore load.PerClus } } -func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { +func (d *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { // Don't drop unless the inner picker is READY. Similar to // https://github.com/grpc/grpc-go/issues/2622. if d.s.ConnectivityState != connectivity.Ready { return d.s.Picker.Pick(info) } + // Check if this RPC should be dropped by category. for _, dp := range d.drops { if dp.drop() { if d.loadStore != nil { @@ -105,6 +116,7 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { } } + // Check if this RPC should be dropped by circuit breaking. if d.counter != nil { if err := d.counter.StartRequest(d.countMax); err != nil { // Drops by circuit breaking are reported with empty category. They @@ -114,11 +126,58 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { } return balancer.PickResult{}, status.Errorf(codes.Unavailable, err.Error()) } - pr, err := d.s.Picker.Pick(info) - if err != nil { + } + + var lIDStr string + pr, err := d.s.Picker.Pick(info) + if scw, ok := pr.SubConn.(*scWrapper); ok { + // This OK check also covers the case err!=nil, because SubConn will be + // nil. + pr.SubConn = scw.SubConn + var e error + // If locality ID isn't found in the wrapper, an empty locality ID will + // be used. + lIDStr, e = scw.localityID().ToString() + if e != nil { + logger.Infof("failed to marshal LocalityID: %#v, loads won't be reported", scw.localityID()) + } + } + + if err != nil { + if d.counter != nil { + // Release one request count if this pick fails. d.counter.EndRequest() - return pr, err } + return pr, err + } + + if d.loadStore != nil { + d.loadStore.CallStarted(lIDStr) + oldDone := pr.Done + pr.Done = func(info balancer.DoneInfo) { + if oldDone != nil { + oldDone(info) + } + d.loadStore.CallFinished(lIDStr, info.Err) + + load, ok := info.ServerLoad.(*orcapb.OrcaLoadReport) + if !ok { + return + } + d.loadStore.CallServerLoad(lIDStr, serverLoadCPUName, load.CpuUtilization) + d.loadStore.CallServerLoad(lIDStr, serverLoadMemoryName, load.MemUtilization) + for n, c := range load.RequestCost { + d.loadStore.CallServerLoad(lIDStr, n, c) + } + for n, c := range load.Utilization { + d.loadStore.CallServerLoad(lIDStr, n, c) + } + } + } + + if d.counter != nil { + // Update Done() so that when the RPC finishes, the request count will + // be released. 
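The Done-chaining just below is the other half of the circuit-breaking contract: StartRequest claims a slot during Pick, and the slot must be released exactly once when the RPC finishes. Reduced to its essentials, with counter and countMax as in this diff and the wrapper function itself illustrative:

```go
func pickWithCircuitBreaking(counter *xdsclient.ClusterRequestsCounter, countMax uint32, pr balancer.PickResult) (balancer.PickResult, error) {
	if err := counter.StartRequest(countMax); err != nil {
		// Over the limit: fail the pick as UNAVAILABLE.
		return balancer.PickResult{}, status.Errorf(codes.Unavailable, err.Error())
	}
	// Chain onto any Done already installed so the slot is released
	// exactly once when the RPC finishes.
	oldDone := pr.Done
	pr.Done = func(info balancer.DoneInfo) {
		counter.EndRequest()
		if oldDone != nil {
			oldDone(info)
		}
	}
	return pr, nil
}
```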
oldDone := pr.Done pr.Done = func(doneInfo balancer.DoneInfo) { d.counter.EndRequest() @@ -126,8 +185,7 @@ func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { oldDone(doneInfo) } } - return pr, err } - return d.s.Picker.Pick(info) + return pr, err } diff --git a/xds/internal/balancer/clustermanager/clustermanager.go b/xds/internal/balancer/clustermanager/clustermanager.go index c00a9a16f458..211133d384e8 100644 --- a/xds/internal/balancer/clustermanager/clustermanager.go +++ b/xds/internal/balancer/clustermanager/clustermanager.go @@ -36,12 +36,12 @@ import ( const balancerName = "xds_cluster_manager_experimental" func init() { - balancer.Register(builder{}) + balancer.Register(bb{}) } -type builder struct{} +type bb struct{} -func (builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { b := &bal{} b.logger = prefixLogger(b) b.stateAggregator = newBalancerStateAggregator(cc, b.logger) @@ -52,11 +52,11 @@ func (builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balance return b } -func (builder) Name() string { +func (bb) Name() string { return balancerName } -func (builder) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } diff --git a/xds/internal/balancer/clusterresolver/clusterresolver.go b/xds/internal/balancer/clusterresolver/clusterresolver.go new file mode 100644 index 000000000000..b48e3e97b716 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/clusterresolver.go @@ -0,0 +1,360 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package clusterresolver contains EDS balancer implementation. +package clusterresolver + +import ( + "encoding/json" + "errors" + "fmt" + + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/base" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/buffer" + "google.golang.org/grpc/internal/grpclog" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/internal/pretty" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +// Name is the name of the cluster_resolver balancer. +const Name = "cluster_resolver_experimental" + +var ( + errBalancerClosed = errors.New("cdsBalancer is closed") + newChildBalancer = func(bb balancer.Builder, cc balancer.ClientConn, o balancer.BuildOptions) balancer.Balancer { + return bb.Build(cc, o) + } +) + +func init() { + balancer.Register(bb{}) +} + +type bb struct{} + +// Build helps implement the balancer.Builder interface. 
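When a backend attaches an ORCA load report to an RPC, the Done callback above fans it out to the load store. A sketch of that fan-out, assuming the orcapb message and constants imported in this diff; the function wrapper is illustrative:

```go
func reportServerLoad(reporter loadReporter, localityKey string, info balancer.DoneInfo) {
	load, ok := info.ServerLoad.(*orcapb.OrcaLoadReport)
	if !ok {
		return // no ORCA report attached to this RPC
	}
	// Well-known utilization metrics first...
	reporter.CallServerLoad(localityKey, serverLoadCPUName, load.CpuUtilization)
	reporter.CallServerLoad(localityKey, serverLoadMemoryName, load.MemUtilization)
	// ...then arbitrary named costs and utilizations from the backend.
	for name, cost := range load.RequestCost {
		reporter.CallServerLoad(localityKey, name, cost)
	}
	for name, util := range load.Utilization {
		reporter.CallServerLoad(localityKey, name, util)
	}
}
```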
+func (bb) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + priorityBuilder := balancer.Get(priority.Name) + if priorityBuilder == nil { + logger.Errorf("priority balancer is needed but not registered") + return nil + } + priorityConfigParser, ok := priorityBuilder.(balancer.ConfigParser) + if !ok { + logger.Errorf("priority balancer builder is not a config parser") + return nil + } + + b := &clusterResolverBalancer{ + bOpts: opts, + updateCh: buffer.NewUnbounded(), + closed: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), + + priorityBuilder: priorityBuilder, + priorityConfigParser: priorityConfigParser, + } + b.logger = prefixLogger(b) + b.logger.Infof("Created") + + b.resourceWatcher = newResourceResolver(b) + b.cc = &ccWrapper{ + ClientConn: cc, + resourceWatcher: b.resourceWatcher, + } + + go b.run() + return b +} + +func (bb) Name() string { + return Name +} + +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + var cfg LBConfig + if err := json.Unmarshal(c, &cfg); err != nil { + return nil, fmt.Errorf("unable to unmarshal balancer config %s into cluster-resolver config, error: %v", string(c), err) + } + return &cfg, nil +} + +// ccUpdate wraps a clientConn update received from gRPC (pushed from the +// xdsResolver). +type ccUpdate struct { + state balancer.ClientConnState + err error +} + +// scUpdate wraps a subConn update received from gRPC. This is directly passed +// on to the child balancer. +type scUpdate struct { + subConn balancer.SubConn + state balancer.SubConnState +} + +// clusterResolverBalancer manages xdsClient and the actual EDS balancer implementation that +// does load balancing. +// +// It currently has only an clusterResolverBalancer. Later, we may add fallback. +type clusterResolverBalancer struct { + cc balancer.ClientConn + bOpts balancer.BuildOptions + updateCh *buffer.Unbounded // Channel for updates from gRPC. + resourceWatcher *resourceResolver + logger *grpclog.PrefixLogger + closed *grpcsync.Event + done *grpcsync.Event + + priorityBuilder balancer.Builder + priorityConfigParser balancer.ConfigParser + + config *LBConfig + configRaw *serviceconfig.ParseResult + xdsClient xdsclient.XDSClient // xDS client to watch EDS resource. + attrsWithClient *attributes.Attributes // Attributes with xdsClient attached to be passed to the child policies. + + child balancer.Balancer + priorities []priorityConfig + watchUpdateReceived bool +} + +// handleClientConnUpdate handles a ClientConnUpdate received from gRPC. Good +// updates lead to registration of EDS and DNS watches. Updates with error lead +// to cancellation of existing watch and propagation of the same error to the +// child balancer. +func (b *clusterResolverBalancer) handleClientConnUpdate(update *ccUpdate) { + // We first handle errors, if any, and then proceed with handling the + // update, only if the status quo has changed. + if err := update.err; err != nil { + b.handleErrorFromUpdate(err, true) + return + } + + b.logger.Infof("Receive update from resolver, balancer config: %v", pretty.ToJSON(update.state.BalancerConfig)) + cfg, _ := update.state.BalancerConfig.(*LBConfig) + if cfg == nil { + b.logger.Warningf("xds: unexpected LoadBalancingConfig type: %T", update.state.BalancerConfig) + return + } + + b.config = cfg + b.configRaw = update.state.ResolverState.ServiceConfig + b.resourceWatcher.updateMechanisms(cfg.DiscoveryMechanisms) + + if !b.watchUpdateReceived { + // If update was not received, wait for it. 
+		return
+	}
+	// If an EDS response was received before this, the child policy was
+	// created. We need to generate a new balancer config and send it to the
+	// child, because certain fields (unrelated to the EDS watch) might have
+	// changed.
+	if err := b.updateChildConfig(); err != nil {
+		b.logger.Warningf("failed to update child policy config: %v", err)
+	}
+}
+
+// handleWatchUpdate handles a watch update from the xDS Client. Good updates
+// lead to clientConn updates being invoked on the underlying child balancer.
+func (b *clusterResolverBalancer) handleWatchUpdate(update *resourceUpdate) {
+	if err := update.err; err != nil {
+		b.logger.Warningf("Watch error from xds-client %p: %v", b.xdsClient, err)
+		b.handleErrorFromUpdate(err, false)
+		return
+	}
+
+	b.logger.Infof("resource update: %+v", pretty.ToJSON(update.priorities))
+	b.watchUpdateReceived = true
+	b.priorities = update.priorities
+
+	// A new EDS update triggers new child configs (e.g. different priorities
+	// for the priority balancer), and new addresses (the endpoints come from
+	// the EDS response).
+	if err := b.updateChildConfig(); err != nil {
+		b.logger.Warningf("failed to update child policy's balancer config: %v", err)
+	}
+}
+
+// updateChildConfig builds a balancer config from b's cached EDS response and
+// service config, and sends that to the child balancer. Note that it also
+// generates the addresses, because the endpoints come from the EDS response.
+//
+// If the child balancer doesn't already exist, one will be created.
+func (b *clusterResolverBalancer) updateChildConfig() error {
+	// The child was built when the first EDS response was received, so here
+	// we just build the config and addresses.
+	if b.child == nil {
+		b.child = newChildBalancer(b.priorityBuilder, b.cc, b.bOpts)
+	}
+
+	childCfgBytes, addrs, err := buildPriorityConfigJSON(b.priorities, b.config.EndpointPickingPolicy)
+	if err != nil {
+		return fmt.Errorf("failed to build priority balancer config: %v", err)
+	}
+	childCfg, err := b.priorityConfigParser.ParseConfig(childCfgBytes)
+	if err != nil {
+		return fmt.Errorf("failed to parse generated priority balancer config, this should never happen because the config is generated: %v", err)
+	}
+	b.logger.Infof("built balancer config: %v", pretty.ToJSON(childCfg))
+	return b.child.UpdateClientConnState(balancer.ClientConnState{
+		ResolverState: resolver.State{
+			Addresses:     addrs,
+			ServiceConfig: b.configRaw,
+			Attributes:    b.attrsWithClient,
+		},
+		BalancerConfig: childCfg,
+	})
+}
+
+// handleErrorFromUpdate handles both the error from the parent ClientConn
+// (from the CDS balancer) and the error from the xds client (from the
+// watcher). fromParent is true if the error is from the parent ClientConn.
+//
+// If the error is a connection error, it should be handled for fallback
+// purposes.
+//
+// If the error is resource-not-found:
+// - If it's from the CDS balancer (shows as a resolver error), it means LDS
+//   or CDS resources were removed. The EDS watch should be canceled.
+// - If it's from the xds client, it means the EDS resource was removed. The
+//   EDS watcher should keep watching.
+// In both cases, the sub-balancers will receive the error.
+func (b *clusterResolverBalancer) handleErrorFromUpdate(err error, fromParent bool) {
+	b.logger.Warningf("Received error: %v", err)
+	if fromParent && xdsclient.ErrType(err) == xdsclient.ErrorTypeResourceNotFound {
+		// This is an error from the parent ClientConn (can be the parent CDS
+		// balancer), and is a resource-not-found error.
This means the resource + // (can be either LDS or CDS) was removed. Stop the EDS watch. + b.resourceWatcher.stop() + } + if b.child != nil { + b.child.ResolverError(err) + } else { + // If eds balancer was never created, fail the RPCs with errors. + b.cc.UpdateState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: base.NewErrPicker(err), + }) + } + +} + +// run is a long-running goroutine which handles all updates from gRPC and +// xdsClient. All methods which are invoked directly by gRPC or xdsClient simply +// push an update onto a channel which is read and acted upon right here. +func (b *clusterResolverBalancer) run() { + for { + select { + case u := <-b.updateCh.Get(): + b.updateCh.Load() + switch update := u.(type) { + case *ccUpdate: + b.handleClientConnUpdate(update) + case *scUpdate: + // SubConn updates are simply handed over to the underlying + // child balancer. + if b.child == nil { + b.logger.Errorf("xds: received scUpdate {%+v} with no child balancer", update) + break + } + b.child.UpdateSubConnState(update.subConn, update.state) + } + case u := <-b.resourceWatcher.updateChannel: + b.handleWatchUpdate(u) + + // Close results in cancellation of the EDS watch and closing of the + // underlying child policy and is the only way to exit this goroutine. + case <-b.closed.Done(): + b.resourceWatcher.stop() + + if b.child != nil { + b.child.Close() + b.child = nil + } + // This is the *ONLY* point of return from this function. + b.logger.Infof("Shutdown") + b.done.Fire() + return + } + } +} + +// Following are methods to implement the balancer interface. + +// UpdateClientConnState receives the serviceConfig (which contains the +// clusterName to watch for in CDS) and the xdsClient object from the +// xdsResolver. +func (b *clusterResolverBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + if b.closed.HasFired() { + b.logger.Warningf("xds: received ClientConnState {%+v} after clusterResolverBalancer was closed", state) + return errBalancerClosed + } + + if b.xdsClient == nil { + c := xdsclient.FromResolverState(state.ResolverState) + if c == nil { + return balancer.ErrBadResolverState + } + b.xdsClient = c + b.attrsWithClient = state.ResolverState.Attributes + } + + b.updateCh.Put(&ccUpdate{state: state}) + return nil +} + +// ResolverError handles errors reported by the xdsResolver. +func (b *clusterResolverBalancer) ResolverError(err error) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received resolver error {%v} after clusterResolverBalancer was closed", err) + return + } + b.updateCh.Put(&ccUpdate{err: err}) +} + +// UpdateSubConnState handles subConn updates from gRPC. +func (b *clusterResolverBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + if b.closed.HasFired() { + b.logger.Warningf("xds: received subConn update {%v, %v} after clusterResolverBalancer was closed", sc, state) + return + } + b.updateCh.Put(&scUpdate{subConn: sc, state: state}) +} + +// Close closes the cdsBalancer and the underlying child balancer. +func (b *clusterResolverBalancer) Close() { + b.closed.Fire() + <-b.done.Done() +} + +// ccWrapper overrides ResolveNow(), so that re-resolution from the child +// policies will trigger the DNS resolver in cluster_resolver balancer. 
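The structured error helpers make that branching explicit: the same resource-not-found type is handled differently depending on whether it arrived from the parent ClientConn or from the watch. A condensed sketch of the dispatch, restating handleErrorFromUpdate with an illustrative method name:

```go
func (b *clusterResolverBalancer) dispatchErrSketch(err error, fromParent bool) {
	if fromParent && xdsclient.ErrType(err) == xdsclient.ErrorTypeResourceNotFound {
		// The LDS/CDS resource above us is gone: stop EDS/DNS resolution.
		b.resourceWatcher.stop()
	}
	// Connection errors and EDS-level resource removals keep the watch
	// alive; in every case the child still sees the error.
	if b.child != nil {
		b.child.ResolverError(err)
	}
}
```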
+type ccWrapper struct { + balancer.ClientConn + resourceWatcher *resourceResolver +} + +func (c *ccWrapper) ResolveNow(resolver.ResolveNowOptions) { + c.resourceWatcher.resolveNow() +} diff --git a/xds/internal/balancer/clusterresolver/clusterresolver_test.go b/xds/internal/balancer/clusterresolver/clusterresolver_test.go new file mode 100644 index 000000000000..e7d0cd347cb7 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/clusterresolver_test.go @@ -0,0 +1,500 @@ +// +build go1.12 + +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package clusterresolver + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // V2 client registration. +) + +const ( + defaultTestTimeout = 1 * time.Second + defaultTestShortTimeout = 10 * time.Millisecond + testEDSServcie = "test-eds-service-name" + testClusterName = "test-cluster-name" +) + +var ( + // A non-empty endpoints update which is expected to be accepted by the EDS + // LB policy. + defaultEndpointsUpdate = xdsclient.EndpointsUpdate{ + Localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{{Address: "endpoint1"}}, + ID: internal.LocalityID{Zone: "zone"}, + Priority: 1, + Weight: 100, + }, + }, + } +) + +func init() { + balancer.Register(bb{}) +} + +type s struct { + grpctest.Tester + + cleanup func() +} + +func (ss s) Teardown(t *testing.T) { + xdsclient.ClearAllCountersForTesting() + ss.Tester.Teardown(t) + if ss.cleanup != nil { + ss.cleanup() + } +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +const testBalancerNameFooBar = "foo.bar" + +func newNoopTestClientConn() *noopTestClientConn { + return &noopTestClientConn{} +} + +// noopTestClientConn is used in EDS balancer config update tests that only +// cover the config update handling, but not SubConn/load-balancing. 
+type noopTestClientConn struct { + balancer.ClientConn +} + +func (t *noopTestClientConn) NewSubConn([]resolver.Address, balancer.NewSubConnOptions) (balancer.SubConn, error) { + return nil, nil +} + +func (noopTestClientConn) Target() string { return testEDSServcie } + +type scStateChange struct { + sc balancer.SubConn + state balancer.SubConnState +} + +type fakeChildBalancer struct { + cc balancer.ClientConn + subConnState *testutils.Channel + clientConnState *testutils.Channel + resolverError *testutils.Channel +} + +func (f *fakeChildBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + f.clientConnState.Send(state) + return nil +} + +func (f *fakeChildBalancer) ResolverError(err error) { + f.resolverError.Send(err) +} + +func (f *fakeChildBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) { + f.subConnState.Send(&scStateChange{sc: sc, state: state}) +} + +func (f *fakeChildBalancer) Close() {} + +func (f *fakeChildBalancer) waitForClientConnStateChange(ctx context.Context) error { + _, err := f.clientConnState.Receive(ctx) + if err != nil { + return err + } + return nil +} + +func (f *fakeChildBalancer) waitForResolverError(ctx context.Context) error { + _, err := f.resolverError.Receive(ctx) + if err != nil { + return err + } + return nil +} + +func (f *fakeChildBalancer) waitForSubConnStateChange(ctx context.Context, wantState *scStateChange) error { + val, err := f.subConnState.Receive(ctx) + if err != nil { + return err + } + gotState := val.(*scStateChange) + if !cmp.Equal(gotState, wantState, cmp.AllowUnexported(scStateChange{})) { + return fmt.Errorf("got subconnStateChange %v, want %v", gotState, wantState) + } + return nil +} + +func newFakeChildBalancer(cc balancer.ClientConn) balancer.Balancer { + return &fakeChildBalancer{ + cc: cc, + subConnState: testutils.NewChannelWithSize(10), + clientConnState: testutils.NewChannelWithSize(10), + resolverError: testutils.NewChannelWithSize(10), + } +} + +type fakeSubConn struct{} + +func (*fakeSubConn) UpdateAddresses([]resolver.Address) { panic("implement me") } +func (*fakeSubConn) Connect() { panic("implement me") } + +// waitForNewChildLB makes sure that a new child LB is created by the top-level +// clusterResolverBalancer. +func waitForNewChildLB(ctx context.Context, ch *testutils.Channel) (*fakeChildBalancer, error) { + val, err := ch.Receive(ctx) + if err != nil { + return nil, fmt.Errorf("error when waiting for a new edsLB: %v", err) + } + return val.(*fakeChildBalancer), nil +} + +// setup overrides the functions which are used to create the xdsClient and the +// edsLB, creates fake version of them and makes them available on the provided +// channels. The returned cancel function should be called by the test for +// cleanup. +func setup(childLBCh *testutils.Channel) (*fakeclient.Client, func()) { + xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) + + origNewChildBalancer := newChildBalancer + newChildBalancer = func(_ balancer.Builder, cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { + childLB := newFakeChildBalancer(cc) + defer func() { childLBCh.Send(childLB) }() + return childLB + } + return xdsC, func() { + newChildBalancer = origNewChildBalancer + xdsC.Close() + } +} + +// TestSubConnStateChange verifies if the top-level clusterResolverBalancer passes on +// the subConnState to appropriate child balancer. 
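setup above relies on newChildBalancer being a package-level variable rather than a direct call, which gives tests a seam to observe child creation. The override-and-restore idiom in isolation (the wrapper function is illustrative):

```go
func overrideChildBuilder(ch *testutils.Channel) (restore func()) {
	orig := newChildBalancer
	newChildBalancer = func(_ balancer.Builder, cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer {
		child := newFakeChildBalancer(cc)
		ch.Send(child) // let the test grab the fake
		return child
	}
	return func() { newChildBalancer = orig }
}
```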
+func (s) TestSubConnStateChange(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testEDSServcie}}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + xdsC.InvokeWatchEDSCallback("", defaultEndpointsUpdate, nil) + edsLB, err := waitForNewChildLB(ctx, edsLBCh) + if err != nil { + t.Fatal(err) + } + + fsc := &fakeSubConn{} + state := balancer.SubConnState{ConnectivityState: connectivity.Ready} + edsB.UpdateSubConnState(fsc, state) + if err := edsLB.waitForSubConnStateChange(ctx, &scStateChange{sc: fsc, state: state}); err != nil { + t.Fatal(err) + } +} + +// TestErrorFromXDSClientUpdate verifies that an error from xdsClient update is +// handled correctly. +// +// If it's resource-not-found, watch will NOT be canceled, the EDS impl will +// receive an empty EDS update, and new RPCs will fail. +// +// If it's connection error, nothing will happen. This will need to change to +// handle fallback. +func (s) TestErrorFromXDSClientUpdate(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testEDSServcie}}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil) + edsLB, err := waitForNewChildLB(ctx, edsLBCh) + if err != nil { + t.Fatal(err) + } + if err := edsLB.waitForClientConnStateChange(ctx); err != nil { + t.Fatalf("EDS impl got unexpected update: %v", err) + } + + connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") + xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, connectionErr) + + sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { + t.Fatal("watch was canceled, want not canceled (timeout error)") + } + + sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) + defer sCancel() + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal(err) + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + resourceErr := 
xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "clusterResolverBalancer resource not found error")
+ xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, resourceErr)
+ // Even if the error is resource-not-found, the watch shouldn't be canceled,
+ // because this indicates that an EDS resource was removed (the xds client
+ // never actually sends this error, but we handle it anyway).
+ sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
+ defer sCancel()
+ if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded {
+ t.Fatal("watch was canceled, want not canceled (timeout error)")
+ }
+ if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded {
+ t.Fatalf("EDS impl got update, want timeout error, got %v", err)
+ }
+ if err := edsLB.waitForResolverError(ctx); err != nil {
+ t.Fatalf("want resolver error, got %v", err)
+ }
+
+ // An update with the same service name should not trigger a new watch.
+ if err := edsB.UpdateClientConnState(balancer.ClientConnState{
+ ResolverState: xdsclient.SetClient(resolver.State{}, xdsC),
+ BalancerConfig: newLBConfigWithOneEDS(testEDSServcie),
+ }); err != nil {
+ t.Fatal(err)
+ }
+ sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
+ defer sCancel()
+ if _, err := xdsC.WaitForWatchEDS(sCtx); err != context.DeadlineExceeded {
+ t.Fatal("got unexpected new EDS watch")
+ }
+}
+
+// TestErrorFromResolver verifies that resolver errors are handled correctly.
+//
+// If it's resource-not-found, the watch will be canceled, the EDS impl will
+// receive an empty EDS update, and new RPCs will fail.
+//
+// If it's a connection error, nothing will happen. This will need to change to
+// handle fallback.
+func (s) TestErrorFromResolver(t *testing.T) {
+ edsLBCh := testutils.NewChannel()
+ xdsC, cleanup := setup(edsLBCh)
+ defer cleanup()
+
+ builder := balancer.Get(Name)
+ edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testEDSServcie}})
+ if edsB == nil {
+ t.Fatalf("builder.Build(%s) failed and returned nil", Name)
+ }
+ defer edsB.Close()
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if err := edsB.UpdateClientConnState(balancer.ClientConnState{
+ ResolverState: xdsclient.SetClient(resolver.State{}, xdsC),
+ BalancerConfig: newLBConfigWithOneEDS(testEDSServcie),
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ if _, err := xdsC.WaitForWatchEDS(ctx); err != nil {
+ t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
+ }
+ xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil)
+ edsLB, err := waitForNewChildLB(ctx, edsLBCh)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := edsLB.waitForClientConnStateChange(ctx); err != nil {
+ t.Fatalf("EDS impl did not receive an update: %v", err)
+ }
+
+ connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error")
+ edsB.ResolverError(connectionErr)
+
+ sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
+ defer sCancel()
+ if _, err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded {
+ t.Fatal("watch was canceled, want not canceled (timeout error)")
+ }
+
+ sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout)
+ defer sCancel()
+ if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded {
+ t.Fatal("eds impl got EDS resp, want timeout error")
+ }
+ if err := edsLB.waitForResolverError(ctx); err != nil {
+ t.Fatalf("want resolver error, 
got %v", err) + } + + resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "clusterResolverBalancer resource not found error") + edsB.ResolverError(resourceErr) + if _, err := xdsC.WaitForCancelEDSWatch(ctx); err != nil { + t.Fatalf("want watch to be canceled, waitForCancel failed: %v", err) + } + if err := edsLB.waitForClientConnStateChange(sCtx); err != context.DeadlineExceeded { + t.Fatal(err) + } + if err := edsLB.waitForResolverError(ctx); err != nil { + t.Fatalf("want resolver error, got %v", err) + } + + // An update with the same service name should trigger a new watch, because + // the previous watch was canceled. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS(testEDSServcie), + }); err != nil { + t.Fatal(err) + } + if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { + t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) + } +} + +// Given a list of resource names, verifies that EDS requests for the same are +// sent by the EDS balancer, through the fake xDS client. +func verifyExpectedRequests(ctx context.Context, fc *fakeclient.Client, resourceNames ...string) error { + for _, name := range resourceNames { + if name == "" { + // ResourceName empty string indicates a cancel. + if _, err := fc.WaitForCancelEDSWatch(ctx); err != nil { + return fmt.Errorf("timed out when expecting resource %q", name) + } + continue + } + + resName, err := fc.WaitForWatchEDS(ctx) + if err != nil { + return fmt.Errorf("timed out when expecting resource %q, %p", name, fc) + } + if resName != name { + return fmt.Errorf("got EDS request for resource %q, expected: %q", resName, name) + } + } + return nil +} + +// TestClientWatchEDS verifies that the xdsClient inside the top-level EDS LB +// policy registers an EDS watch for expected resource upon receiving an update +// from gRPC. +func (s) TestClientWatchEDS(t *testing.T) { + edsLBCh := testutils.NewChannel() + xdsC, cleanup := setup(edsLBCh) + defer cleanup() + + builder := balancer.Get(Name) + edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testEDSServcie}}) + if edsB == nil { + t.Fatalf("builder.Build(%s) failed and returned nil", Name) + } + defer edsB.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + // If eds service name is not set, should watch for cluster name. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS("cluster-1"), + }); err != nil { + t.Fatal(err) + } + if err := verifyExpectedRequests(ctx, xdsC, "cluster-1"); err != nil { + t.Fatal(err) + } + + // Update with an non-empty edsServiceName should trigger an EDS watch for + // the same. + if err := edsB.UpdateClientConnState(balancer.ClientConnState{ + ResolverState: xdsclient.SetClient(resolver.State{}, xdsC), + BalancerConfig: newLBConfigWithOneEDS("foobar-1"), + }); err != nil { + t.Fatal(err) + } + if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-1"); err != nil { + t.Fatal(err) + } + + // Also test the case where the edsServerName changes from one non-empty + // name to another, and make sure a new watch is registered. The previously + // registered watch will be cancelled, which will result in an EDS request + // with no resource names being sent to the server. 
+ if err := edsB.UpdateClientConnState(balancer.ClientConnState{
+ ResolverState: xdsclient.SetClient(resolver.State{}, xdsC),
+ BalancerConfig: newLBConfigWithOneEDS("foobar-2"),
+ }); err != nil {
+ t.Fatal(err)
+ }
+ if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-2"); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func newLBConfigWithOneEDS(edsServiceName string) *LBConfig {
+ return &LBConfig{
+ DiscoveryMechanisms: []DiscoveryMechanism{{
+ Cluster: testClusterName,
+ Type: DiscoveryMechanismTypeEDS,
+ EDSServiceName: edsServiceName,
+ }},
+ }
+}
diff --git a/xds/internal/balancer/clusterresolver/config.go b/xds/internal/balancer/clusterresolver/config.go
new file mode 100644
index 000000000000..3bcb432b4091
--- /dev/null
+++ b/xds/internal/balancer/clusterresolver/config.go
@@ -0,0 +1,178 @@
+/*
+ *
+ * Copyright 2021 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package clusterresolver
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+
+ internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
+ "google.golang.org/grpc/serviceconfig"
+)
+
+// DiscoveryMechanismType is the type of discovery mechanism.
+type DiscoveryMechanismType int
+
+const (
+ // DiscoveryMechanismTypeEDS is EDS.
+ DiscoveryMechanismTypeEDS DiscoveryMechanismType = iota // `json:"EDS"`
+ // DiscoveryMechanismTypeLogicalDNS is DNS.
+ DiscoveryMechanismTypeLogicalDNS // `json:"LOGICAL_DNS"`
+)
+
+// MarshalJSON marshals a DiscoveryMechanismType to a quoted json string.
+//
+// This is necessary to handle the enum (as strings) in JSON.
+//
+// Note that this needs to be defined on the value type, not the pointer,
+// otherwise variables of this type will marshal to an int, not a string.
+func (t DiscoveryMechanismType) MarshalJSON() ([]byte, error) {
+ buffer := bytes.NewBufferString(`"`)
+ switch t {
+ case DiscoveryMechanismTypeEDS:
+ buffer.WriteString("EDS")
+ case DiscoveryMechanismTypeLogicalDNS:
+ buffer.WriteString("LOGICAL_DNS")
+ }
+ buffer.WriteString(`"`)
+ return buffer.Bytes(), nil
+}
+
+// UnmarshalJSON unmarshals a quoted json string to the DiscoveryMechanismType.
+func (t *DiscoveryMechanismType) UnmarshalJSON(b []byte) error {
+ var s string
+ err := json.Unmarshal(b, &s)
+ if err != nil {
+ return err
+ }
+ switch s {
+ case "EDS":
+ *t = DiscoveryMechanismTypeEDS
+ case "LOGICAL_DNS":
+ *t = DiscoveryMechanismTypeLogicalDNS
+ default:
+ return fmt.Errorf("unable to unmarshal string %q to type DiscoveryMechanismType", s)
+ }
+ return nil
+}
+
+// DiscoveryMechanism is the discovery mechanism; it can be either EDS or DNS.
+//
+// For DNS, the ClientConn target will be used for name resolution.
+//
+// For EDS, if EDSServiceName is not empty, it will be used for watching. If
+// EDSServiceName is empty, Cluster will be used.
+type DiscoveryMechanism struct {
+ // Cluster is the cluster name.
+ Cluster string `json:"cluster,omitempty"`
+ // LoadReportingServerName is the LRS server to send load reports to. If
+ // not present, load reporting will be disabled. 
If set to the empty string,
+ // load reporting will be sent to the same server that we obtained CDS data
+ // from.
+ LoadReportingServerName *string `json:"lrsLoadReportingServerName,omitempty"`
+ // MaxConcurrentRequests is the maximum number of outstanding requests that
+ // can be made to the upstream cluster. Default is 1024.
+ MaxConcurrentRequests *uint32 `json:"maxConcurrentRequests,omitempty"`
+ // Type is the discovery mechanism type.
+ Type DiscoveryMechanismType `json:"type,omitempty"`
+ // EDSServiceName is the EDS service name, as returned in CDS. May be unset
+ // if not specified in CDS. For type EDS only.
+ //
+ // This is used for the EDS watch if set. If unset, Cluster is used for the
+ // EDS watch.
+ EDSServiceName string `json:"edsServiceName,omitempty"`
+ // DNSHostname is the DNS name to resolve in "host:port" form. For type
+ // LOGICAL_DNS only.
+ DNSHostname string `json:"dnsHostname,omitempty"`
+}
+
+// Equal returns whether the DiscoveryMechanism is the same as the parameter.
+func (dm DiscoveryMechanism) Equal(b DiscoveryMechanism) bool {
+ switch {
+ case dm.Cluster != b.Cluster:
+ return false
+ case !equalStringP(dm.LoadReportingServerName, b.LoadReportingServerName):
+ return false
+ case !equalUint32P(dm.MaxConcurrentRequests, b.MaxConcurrentRequests):
+ return false
+ case dm.Type != b.Type:
+ return false
+ case dm.EDSServiceName != b.EDSServiceName:
+ return false
+ case dm.DNSHostname != b.DNSHostname:
+ return false
+ }
+ return true
+}
+
+func equalStringP(a, b *string) bool {
+ if a == nil && b == nil {
+ return true
+ }
+ if a == nil || b == nil {
+ return false
+ }
+ return *a == *b
+}
+
+func equalUint32P(a, b *uint32) bool {
+ if a == nil && b == nil {
+ return true
+ }
+ if a == nil || b == nil {
+ return false
+ }
+ return *a == *b
+}
+
+// LBConfig is the config for the cluster resolver balancer.
+type LBConfig struct {
+ serviceconfig.LoadBalancingConfig `json:"-"`
+ // DiscoveryMechanisms is an ordered list of discovery mechanisms.
+ //
+ // Must have at least one element. Results from each discovery mechanism are
+ // concatenated together in successive priorities.
+ DiscoveryMechanisms []DiscoveryMechanism `json:"discoveryMechanisms,omitempty"`
+
+ // LocalityPickingPolicy is the policy for locality picking.
+ //
+ // This policy's config is expected to be in the format used by the
+ // weighted_target policy. Note that the config should include an empty
+ // value for the "targets" field; that empty value will be replaced by one
+ // that is dynamically generated based on the EDS data. Optional; defaults
+ // to "weighted_target".
+ LocalityPickingPolicy *internalserviceconfig.BalancerConfig `json:"localityPickingPolicy,omitempty"`
+
+ // EndpointPickingPolicy is the policy for endpoint picking.
+ //
+ // This will be configured as the policy for each child in the
+ // locality-policy's config. Optional; defaults to "round_robin".
+ EndpointPickingPolicy *internalserviceconfig.BalancerConfig `json:"endpointPickingPolicy,omitempty"`
+
+ // TODO: read and warn if endpoint is not roundrobin or locality is not
+ // weightedtarget.
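+
+ // An illustrative JSON form of this config (the field values here are
+ // examples only; see testJSONConfig1 in config_test.go for a fuller one):
+ //
+ // {
+ //   "discoveryMechanisms": [{
+ //     "cluster": "cluster-name",
+ //     "type": "EDS",
+ //     "edsServiceName": "eds-service-name"
+ //   }]
+ // }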
+}
+
+func parseConfig(c json.RawMessage) (*LBConfig, error) {
+ var cfg LBConfig
+ if err := json.Unmarshal(c, &cfg); err != nil {
+ return nil, err
+ }
+ return &cfg, nil
+}
diff --git a/xds/internal/balancer/clusterresolver/config_test.go b/xds/internal/balancer/clusterresolver/config_test.go
new file mode 100644
index 000000000000..77b14deb6abe
--- /dev/null
+++ b/xds/internal/balancer/clusterresolver/config_test.go
@@ -0,0 +1,224 @@
+// +build go1.12
+
+/*
+ *
+ * Copyright 2021 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+package clusterresolver
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+ internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
+)
+
+func TestDiscoveryMechanismTypeMarshalJSON(t *testing.T) {
+ tests := []struct {
+ name string
+ typ DiscoveryMechanismType
+ want string
+ }{
+ {
+ name: "eds",
+ typ: DiscoveryMechanismTypeEDS,
+ want: `"EDS"`,
+ },
+ {
+ name: "dns",
+ typ: DiscoveryMechanismTypeLogicalDNS,
+ want: `"LOGICAL_DNS"`,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ if got, err := json.Marshal(tt.typ); err != nil || string(got) != tt.want {
+ t.Fatalf("DiscoveryMechanismType.MarshalJSON() = (%v, %v), want (%s, nil)", string(got), err, tt.want)
+ }
+ })
+ }
+}
+
+func TestDiscoveryMechanismTypeUnmarshalJSON(t *testing.T) {
+ tests := []struct {
+ name string
+ js string
+ want DiscoveryMechanismType
+ wantErr bool
+ }{
+ {
+ name: "eds",
+ js: `"EDS"`,
+ want: DiscoveryMechanismTypeEDS,
+ },
+ {
+ name: "dns",
+ js: `"LOGICAL_DNS"`,
+ want: DiscoveryMechanismTypeLogicalDNS,
+ },
+ {
+ name: "error",
+ js: `"1234"`,
+ wantErr: true,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ var got DiscoveryMechanismType
+ err := json.Unmarshal([]byte(tt.js), &got)
+ if (err != nil) != tt.wantErr {
+ t.Fatalf("DiscoveryMechanismType.UnmarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
+ }
+ if diff := cmp.Diff(got, tt.want); diff != "" {
+ t.Fatalf("DiscoveryMechanismType.UnmarshalJSON() got unexpected output, diff (-got +want): %v", diff)
+ }
+ })
+ }
+}
+
+const (
+ testJSONConfig1 = `{
+ "discoveryMechanisms": [{
+ "cluster": "test-cluster-name",
+ "lrsLoadReportingServerName": "test-lrs-server",
+ "maxConcurrentRequests": 314,
+ "type": "EDS",
+ "edsServiceName": "test-eds-service-name"
+ }]
+}`
+ testJSONConfig2 = `{
+ "discoveryMechanisms": [{
+ "cluster": "test-cluster-name",
+ "lrsLoadReportingServerName": "test-lrs-server",
+ "maxConcurrentRequests": 314,
+ "type": "EDS",
+ "edsServiceName": "test-eds-service-name"
+ },{
+ "type": "LOGICAL_DNS"
+ }]
}`
+ testJSONConfig3 = `{
+ "discoveryMechanisms": [{
+ "cluster": "test-cluster-name",
+ "lrsLoadReportingServerName": "test-lrs-server",
+ "maxConcurrentRequests": 314,
+ "type": "EDS",
+ "edsServiceName": "test-eds-service-name"
+ }],
+ "localityPickingPolicy":[{"pick_first":{}}],
+ "endpointPickingPolicy":[{"pick_first":{}}]
+}`
+)
+
+func 
TestParseConfig(t *testing.T) { + tests := []struct { + name string + js string + want *LBConfig + wantErr bool + }{ + { + name: "empty json", + js: "", + want: nil, + wantErr: true, + }, + { + name: "OK with one discovery mechanism", + js: testJSONConfig1, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + }, + LocalityPickingPolicy: nil, + EndpointPickingPolicy: nil, + }, + wantErr: false, + }, + { + name: "OK with multiple discovery mechanisms", + js: testJSONConfig2, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + { + Type: DiscoveryMechanismTypeLogicalDNS, + }, + }, + LocalityPickingPolicy: nil, + EndpointPickingPolicy: nil, + }, + wantErr: false, + }, + { + name: "OK with picking policy override", + js: testJSONConfig3, + want: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{ + { + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServcie, + }, + }, + LocalityPickingPolicy: &internalserviceconfig.BalancerConfig{ + Name: "pick_first", + Config: nil, + }, + EndpointPickingPolicy: &internalserviceconfig.BalancerConfig{ + Name: "pick_first", + Config: nil, + }, + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseConfig([]byte(tt.js)) + if (err != nil) != tt.wantErr { + t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("parseConfig() got unexpected output, diff (-got +want): %v", diff) + } + }) + } +} + +func newString(s string) *string { + return &s +} + +func newUint32(i uint32) *uint32 { + return &i +} diff --git a/xds/internal/balancer/clusterresolver/configbuilder.go b/xds/internal/balancer/clusterresolver/configbuilder.go new file mode 100644 index 000000000000..dfbdc1e2d671 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/configbuilder.go @@ -0,0 +1,270 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ *
+ */
+
+package clusterresolver
+
+import (
+ "encoding/json"
+ "fmt"
+ "sort"
+
+ "google.golang.org/grpc/balancer/roundrobin"
+ "google.golang.org/grpc/balancer/weightedroundrobin"
+ "google.golang.org/grpc/internal/hierarchy"
+ internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
+ "google.golang.org/grpc/resolver"
+ "google.golang.org/grpc/xds/internal"
+ "google.golang.org/grpc/xds/internal/balancer/clusterimpl"
+ "google.golang.org/grpc/xds/internal/balancer/priority"
+ "google.golang.org/grpc/xds/internal/balancer/weightedtarget"
+ "google.golang.org/grpc/xds/internal/xdsclient"
+)
+
+const million = 1000000
+
+// priorityConfig is the config for one priority. For example, if there are an
+// EDS mechanism and a DNS mechanism, the priority list will be
+// [priorityConfig{EDS}, priorityConfig{DNS}].
+//
+// Each priorityConfig corresponds to one discovery mechanism from the LBConfig
+// generated by the CDS balancer. The CDS balancer resolves the cluster name to
+// an ordered list of discovery mechanisms (if the top cluster is an aggregated
+// cluster), one for each underlying cluster.
+type priorityConfig struct {
+ mechanism DiscoveryMechanism
+ // edsResp is set only if type is EDS.
+ edsResp xdsclient.EndpointsUpdate
+ // addresses is set only if type is DNS.
+ addresses []string
+}
+
+// buildPriorityConfigJSON builds balancer config for the passed in
+// priorities.
+//
+// The tree of balancers built looks like this (see the test for the output
+// struct):
+//
+// ┌────────┐
+// │priority│
+// └┬──────┬┘
+// │ │
+// ┌───────────▼┐ ┌▼───────────┐
+// │cluster_impl│ │cluster_impl│
+// └─┬──────────┘ └──────────┬─┘
+// │ │
+// ┌──────────────▼─┐ ┌─▼──────────────┐
+// │locality_picking│ │locality_picking│
+// └┬──────────────┬┘ └┬──────────────┬┘
+// │ │ │ │
+// ┌─▼─┐ ┌─▼─┐ ┌─▼─┐ ┌─▼─┐
+// │LRS│ │LRS│ │LRS│ │LRS│
+// └─┬─┘ └─┬─┘ └─┬─┘ └─┬─┘
+// │ │ │ │
+// ┌──────────▼─────┐ ┌─────▼──────────┐ ┌──────────▼─────┐ ┌─────▼──────────┐
+// │endpoint_picking│ │endpoint_picking│ │endpoint_picking│ │endpoint_picking│
+// └────────────────┘ └────────────────┘ └────────────────┘ └────────────────┘
+//
+// If endpointPickingPolicy is nil, roundrobin will be used.
+//
+// A custom locality picking policy isn't supported, and weighted_target is
+// always used.
+//
+// TODO: support setting the locality picking policy, and add a parameter for
+// it.
+func buildPriorityConfigJSON(priorities []priorityConfig, endpointPickingPolicy *internalserviceconfig.BalancerConfig) ([]byte, []resolver.Address, error) {
+ pc, addrs := buildPriorityConfig(priorities, endpointPickingPolicy)
+ ret, err := json.Marshal(pc)
+ if err != nil {
+ return nil, nil, fmt.Errorf("failed to marshal built priority config struct into json: %v", err)
+ }
+ return ret, addrs, nil
+}
+
+func buildPriorityConfig(priorities []priorityConfig, endpointPickingPolicy *internalserviceconfig.BalancerConfig) (*priority.LBConfig, []resolver.Address) {
+ var (
+ retConfig = &priority.LBConfig{Children: make(map[string]*priority.Child)}
+ retAddrs []resolver.Address
+ )
+ for i, p := range priorities {
+ switch p.mechanism.Type {
+ case DiscoveryMechanismTypeEDS:
+ names, configs, addrs := buildClusterImplConfigForEDS(i, p.edsResp, p.mechanism, endpointPickingPolicy)
+ retConfig.Priorities = append(retConfig.Priorities, names...)
+ for n, c := range configs {
+ retConfig.Children[n] = &priority.Child{
+ Config: &internalserviceconfig.BalancerConfig{Name: clusterimpl.Name, Config: c},
+ // Ignore all re-resolution requests from EDS children.
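+ // Their endpoints come from the management server, not from the
+ // ClientConn's name resolver, so re-resolving would not produce new
+ // endpoints for them.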
+ IgnoreReresolutionRequests: true,
+ }
+ }
+ retAddrs = append(retAddrs, addrs...)
+ case DiscoveryMechanismTypeLogicalDNS:
+ name, config, addrs := buildClusterImplConfigForDNS(i, p.addresses)
+ retConfig.Priorities = append(retConfig.Priorities, name)
+ retConfig.Children[name] = &priority.Child{
+ Config: &internalserviceconfig.BalancerConfig{Name: clusterimpl.Name, Config: config},
+ // Don't ignore re-resolution requests from DNS children; they
+ // trigger the DNS resolver to re-resolve.
+ IgnoreReresolutionRequests: false,
+ }
+ retAddrs = append(retAddrs, addrs...)
+ }
+ }
+ return retConfig, retAddrs
+}
+
+func buildClusterImplConfigForDNS(parentPriority int, addrStrs []string) (string, *clusterimpl.LBConfig, []resolver.Address) {
+ // The endpoint picking policy for DNS is hardcoded to pick_first.
+ const childPolicy = "pick_first"
+ var retAddrs []resolver.Address
+ pName := fmt.Sprintf("priority-%v", parentPriority)
+ for _, addrStr := range addrStrs {
+ retAddrs = append(retAddrs, hierarchy.Set(resolver.Address{Addr: addrStr}, []string{pName}))
+ }
+ return pName, &clusterimpl.LBConfig{ChildPolicy: &internalserviceconfig.BalancerConfig{Name: childPolicy}}, retAddrs
+}
+
+// buildClusterImplConfigForEDS returns a list of cluster_impl configs, one for
+// each priority, sorted by priority, and the addresses for each priority (with
+// hierarchy attributes set).
+//
+// For example, if there are two priorities, the returned values will be
+// - ["p0", "p1"]
+// - map{"p0":p0_config, "p1":p1_config}
+// - [p0_address_0, p0_address_1, p1_address_0, p1_address_1]
+// - p0 addresses' hierarchy attributes are set to p0
+func buildClusterImplConfigForEDS(parentPriority int, edsResp xdsclient.EndpointsUpdate, mechanism DiscoveryMechanism, endpointPickingPolicy *internalserviceconfig.BalancerConfig) ([]string, map[string]*clusterimpl.LBConfig, []resolver.Address) {
+ var (
+ retNames []string
+ retAddrs []resolver.Address
+ retConfigs = make(map[string]*clusterimpl.LBConfig)
+ )
+
+ if endpointPickingPolicy == nil {
+ endpointPickingPolicy = &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}
+ }
+
+ drops := make([]clusterimpl.DropConfig, 0, len(edsResp.Drops))
+ for _, d := range edsResp.Drops {
+ drops = append(drops, clusterimpl.DropConfig{
+ Category: d.Category,
+ RequestsPerMillion: d.Numerator * million / d.Denominator,
+ })
+ }
+
+ priorityChildNames, priorities := groupLocalitiesByPriority(edsResp.Localities)
+ for _, priorityName := range priorityChildNames {
+ priorityLocalities := priorities[priorityName]
+ // Prepend the parent priority to the priority names, to avoid
+ // duplicates.
+ pName := fmt.Sprintf("priority-%v-%v", parentPriority, priorityName)
+ retNames = append(retNames, pName)
+ wtConfig, addrs := localitiesToWeightedTarget(priorityLocalities, pName, endpointPickingPolicy, mechanism.Cluster, mechanism.EDSServiceName)
+ retConfigs[pName] = &clusterimpl.LBConfig{
+ Cluster: mechanism.Cluster,
+ EDSServiceName: mechanism.EDSServiceName,
+ ChildPolicy: &internalserviceconfig.BalancerConfig{Name: weightedtarget.Name, Config: wtConfig},
+ LoadReportingServerName: mechanism.LoadReportingServerName,
+ MaxConcurrentRequests: mechanism.MaxConcurrentRequests,
+ DropCategories: drops,
+ }
+ retAddrs = append(retAddrs, addrs...)
+ }
+
+ return retNames, retConfigs, retAddrs
+}
+
+// groupLocalitiesByPriority returns the localities grouped by priority.
+//
+// It also returns a list of strings where each string represents a priority,
+// and the list is sorted from highest priority to lowest priority.
+//
+// For example, for L0-p0, L1-p0, L2-p1, the results will be
+// - ["p0", "p1"]
+// - map{"p0":[L0, L1], "p1":[L2]}
+func groupLocalitiesByPriority(localities []xdsclient.Locality) ([]string, map[string][]xdsclient.Locality) {
+ var priorityIntSlice []int
+ priorities := make(map[string][]xdsclient.Locality)
+ for _, locality := range localities {
+ if locality.Weight == 0 {
+ continue
+ }
+ priorityName := fmt.Sprintf("%v", locality.Priority)
+ priorities[priorityName] = append(priorities[priorityName], locality)
+ priorityIntSlice = append(priorityIntSlice, int(locality.Priority))
+ }
+ // Sort the priorities based on the int value, deduplicate, and then turn
+ // the sorted list into a string list. These will be the child names, in
+ // priority order.
+ sort.Ints(priorityIntSlice)
+ priorityIntSliceDeduped := dedupSortedIntSlice(priorityIntSlice)
+ priorityNameSlice := make([]string, 0, len(priorityIntSliceDeduped))
+ for _, p := range priorityIntSliceDeduped {
+ priorityNameSlice = append(priorityNameSlice, fmt.Sprintf("%v", p))
+ }
+ return priorityNameSlice, priorities
+}
+
+func dedupSortedIntSlice(a []int) []int {
+ if len(a) == 0 {
+ return a
+ }
+ i, j := 0, 1
+ for ; j < len(a); j++ {
+ if a[i] == a[j] {
+ continue
+ }
+ i++
+ if i != j {
+ a[i] = a[j]
+ }
+ }
+ return a[:i+1]
+}
+
+// localitiesToWeightedTarget takes a list of localities (with the same
+// priority) and generates a weighted target config and a list of addresses.
+//
+// The addresses have their path hierarchy set to [priority-name,
+// locality-name], so priority and weighted target know which child policy they
+// are for.
+func localitiesToWeightedTarget(localities []xdsclient.Locality, priorityName string, childPolicy *internalserviceconfig.BalancerConfig, cluster, edsService string) (*weightedtarget.LBConfig, []resolver.Address) {
+ weightedTargets := make(map[string]weightedtarget.Target)
+ var addrs []resolver.Address
+ for _, locality := range localities {
+ localityStr, err := locality.ID.ToString()
+ if err != nil {
+ localityStr = fmt.Sprintf("%+v", locality.ID)
+ }
+ weightedTargets[localityStr] = weightedtarget.Target{Weight: locality.Weight, ChildPolicy: childPolicy}
+ for _, endpoint := range locality.Endpoints {
+ // Filter out all "unhealthy" endpoints (unknown and healthy are
+ // both considered to be healthy:
+ // https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/core/health_check.proto#envoy-api-enum-core-healthstatus).
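+ // For example, an endpoint whose HealthStatus is DRAINING or
+ // UNHEALTHY is skipped below, while HEALTHY and UNKNOWN endpoints
+ // are kept.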
+ if endpoint.HealthStatus != xdsclient.EndpointHealthStatusHealthy && endpoint.HealthStatus != xdsclient.EndpointHealthStatusUnknown { + continue + } + + addr := resolver.Address{Addr: endpoint.Address} + if childPolicy.Name == weightedroundrobin.Name && endpoint.Weight != 0 { + ai := weightedroundrobin.AddrInfo{Weight: endpoint.Weight} + addr = weightedroundrobin.SetAddrInfo(addr, ai) + } + addr = hierarchy.Set(addr, []string{priorityName, localityStr}) + addr = internal.SetLocalityID(addr, locality.ID) + addrs = append(addrs, addr) + } + } + return &weightedtarget.LBConfig{Targets: weightedTargets}, addrs +} diff --git a/xds/internal/balancer/clusterresolver/configbuilder_test.go b/xds/internal/balancer/clusterresolver/configbuilder_test.go new file mode 100644 index 000000000000..d8f17d053aae --- /dev/null +++ b/xds/internal/balancer/clusterresolver/configbuilder_test.go @@ -0,0 +1,746 @@ +// +build go1.12 + +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package clusterresolver + +import ( + "bytes" + "encoding/json" + "fmt" + "sort" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/grpc/attributes" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/roundrobin" + "google.golang.org/grpc/balancer/weightedroundrobin" + "google.golang.org/grpc/internal/hierarchy" + internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal" + "google.golang.org/grpc/xds/internal/balancer/clusterimpl" + "google.golang.org/grpc/xds/internal/balancer/priority" + "google.golang.org/grpc/xds/internal/balancer/weightedtarget" + "google.golang.org/grpc/xds/internal/xdsclient" +) + +const ( + testLRSServer = "test-lrs-server" + testMaxRequests = 314 + testEDSServiceName = "service-name-from-parent" + testDropCategory = "test-drops" + testDropOverMillion = 1 + + localityCount = 5 + addressPerLocality = 2 +) + +var ( + testLocalityIDs []internal.LocalityID + testAddressStrs [][]string + testEndpoints [][]xdsclient.Endpoint + + testLocalitiesP0, testLocalitiesP1 []xdsclient.Locality + + addrCmpOpts = cmp.Options{ + cmp.AllowUnexported(attributes.Attributes{}), + cmp.Transformer("SortAddrs", func(in []resolver.Address) []resolver.Address { + out := append([]resolver.Address(nil), in...) 
// Copy input to avoid mutating it + sort.Slice(out, func(i, j int) bool { + return out[i].Addr < out[j].Addr + }) + return out + })} +) + +func init() { + for i := 0; i < localityCount; i++ { + testLocalityIDs = append(testLocalityIDs, internal.LocalityID{Zone: fmt.Sprintf("test-zone-%d", i)}) + var ( + addrs []string + ends []xdsclient.Endpoint + ) + for j := 0; j < addressPerLocality; j++ { + addr := fmt.Sprintf("addr-%d-%d", i, j) + addrs = append(addrs, addr) + ends = append(ends, xdsclient.Endpoint{ + Address: addr, + HealthStatus: xdsclient.EndpointHealthStatusHealthy, + }) + } + testAddressStrs = append(testAddressStrs, addrs) + testEndpoints = append(testEndpoints, ends) + } + + testLocalitiesP0 = []xdsclient.Locality{ + { + Endpoints: testEndpoints[0], + ID: testLocalityIDs[0], + Weight: 20, + Priority: 0, + }, + { + Endpoints: testEndpoints[1], + ID: testLocalityIDs[1], + Weight: 80, + Priority: 0, + }, + } + testLocalitiesP1 = []xdsclient.Locality{ + { + Endpoints: testEndpoints[2], + ID: testLocalityIDs[2], + Weight: 20, + Priority: 1, + }, + { + Endpoints: testEndpoints[3], + ID: testLocalityIDs[3], + Weight: 80, + Priority: 1, + }, + } +} + +// TestBuildPriorityConfigJSON is a sanity check that the built balancer config +// can be parsed. The behavior test is covered by TestBuildPriorityConfig. +func TestBuildPriorityConfigJSON(t *testing.T) { + gotConfig, _, err := buildPriorityConfigJSON([]priorityConfig{ + { + mechanism: DiscoveryMechanism{ + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + edsResp: xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + testLocalitiesP0[0], + testLocalitiesP0[1], + testLocalitiesP1[0], + testLocalitiesP1[1], + }, + }, + }, + { + mechanism: DiscoveryMechanism{ + Type: DiscoveryMechanismTypeLogicalDNS, + }, + addresses: testAddressStrs[4], + }, + }, nil) + if err != nil { + t.Fatalf("buildPriorityConfigJSON(...) failed: %v", err) + } + + var prettyGot bytes.Buffer + if err := json.Indent(&prettyGot, gotConfig, ">>> ", " "); err != nil { + t.Fatalf("json.Indent() failed: %v", err) + } + // Print the indented json if this test fails. 
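+ // (t.Log output is only printed when the test fails or when -v is set.)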
+ t.Log(prettyGot.String()) + + priorityB := balancer.Get(priority.Name) + if _, err = priorityB.(balancer.ConfigParser).ParseConfig(gotConfig); err != nil { + t.Fatalf("ParseConfig(%+v) failed: %v", gotConfig, err) + } +} + +func TestBuildPriorityConfig(t *testing.T) { + gotConfig, gotAddrs := buildPriorityConfig([]priorityConfig{ + { + mechanism: DiscoveryMechanism{ + Cluster: testClusterName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + edsResp: xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + testLocalitiesP0[0], + testLocalitiesP0[1], + testLocalitiesP1[0], + testLocalitiesP1[1], + }, + }, + }, + { + mechanism: DiscoveryMechanism{ + Type: DiscoveryMechanismTypeLogicalDNS, + }, + addresses: testAddressStrs[4], + }, + }, nil) + + wantConfig := &priority.LBConfig{ + Children: map[string]*priority.Child{ + "priority-0-0": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[0].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[1].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + }, + IgnoreReresolutionRequests: true, + }, + "priority-0-1": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[2].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[3].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + }, + IgnoreReresolutionRequests: true, + }, + "priority-1": { + Config: &internalserviceconfig.BalancerConfig{ + Name: clusterimpl.Name, + Config: &clusterimpl.LBConfig{ + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: "pick_first"}, + }, + }, + IgnoreReresolutionRequests: false, + }, + }, + Priorities: []string{"priority-0-0", "priority-0-1", "priority-1"}, + } + wantAddrs := []resolver.Address{ + testAddrWithAttrs(testAddressStrs[0][0], nil, "priority-0-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[0][1], nil, "priority-0-0", &testLocalityIDs[0]), + 
testAddrWithAttrs(testAddressStrs[1][0], nil, "priority-0-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[1][1], nil, "priority-0-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[2][0], nil, "priority-0-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[2][1], nil, "priority-0-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[3][0], nil, "priority-0-1", &testLocalityIDs[3]), + testAddrWithAttrs(testAddressStrs[3][1], nil, "priority-0-1", &testLocalityIDs[3]), + testAddrWithAttrs(testAddressStrs[4][0], nil, "priority-1", nil), + testAddrWithAttrs(testAddressStrs[4][1], nil, "priority-1", nil), + } + + if diff := cmp.Diff(gotConfig, wantConfig); diff != "" { + t.Errorf("buildPriorityConfig() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildPriorityConfig() diff (-got +want) %v", diff) + } +} + +func TestBuildClusterImplConfigForDNS(t *testing.T) { + gotName, gotConfig, gotAddrs := buildClusterImplConfigForDNS(3, testAddressStrs[0]) + wantName := "priority-3" + wantConfig := &clusterimpl.LBConfig{ + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: "pick_first", + }, + } + wantAddrs := []resolver.Address{ + hierarchy.Set(resolver.Address{Addr: testAddressStrs[0][0]}, []string{"priority-3"}), + hierarchy.Set(resolver.Address{Addr: testAddressStrs[0][1]}, []string{"priority-3"}), + } + + if diff := cmp.Diff(gotName, wantName); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotConfig, wantConfig); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildClusterImplConfigForDNS() diff (-got +want) %v", diff) + } +} + +func TestBuildClusterImplConfigForEDS(t *testing.T) { + gotNames, gotConfigs, gotAddrs := buildClusterImplConfigForEDS( + 2, + xdsclient.EndpointsUpdate{ + Drops: []xdsclient.OverloadDropConfig{ + { + Category: testDropCategory, + Numerator: testDropOverMillion, + Denominator: million, + }, + }, + Localities: []xdsclient.Locality{ + { + Endpoints: testEndpoints[3], + ID: testLocalityIDs[3], + Weight: 80, + Priority: 1, + }, { + Endpoints: testEndpoints[1], + ID: testLocalityIDs[1], + Weight: 80, + Priority: 0, + }, { + Endpoints: testEndpoints[2], + ID: testLocalityIDs[2], + Weight: 20, + Priority: 1, + }, { + Endpoints: testEndpoints[0], + ID: testLocalityIDs[0], + Weight: 20, + Priority: 0, + }, + }, + }, + DiscoveryMechanism{ + Cluster: testClusterName, + MaxConcurrentRequests: newUint32(testMaxRequests), + LoadReportingServerName: newString(testLRSServer), + Type: DiscoveryMechanismTypeEDS, + EDSServiceName: testEDSServiceName, + }, + nil, + ) + + wantNames := []string{ + fmt.Sprintf("priority-%v-%v", 2, 0), + fmt.Sprintf("priority-%v-%v", 2, 1), + } + wantConfigs := map[string]*clusterimpl.LBConfig{ + "priority-2-0": { + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[0].ToString): { + Weight: 20, + ChildPolicy: 
&internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[1].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + "priority-2-1": { + Cluster: testClusterName, + EDSServiceName: testEDSServiceName, + LoadReportingServerName: newString(testLRSServer), + MaxConcurrentRequests: newUint32(testMaxRequests), + DropCategories: []clusterimpl.DropConfig{ + { + Category: testDropCategory, + RequestsPerMillion: testDropOverMillion, + }, + }, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedtarget.Name, + Config: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(testLocalityIDs[2].ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(testLocalityIDs[3].ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + }, + }, + } + wantAddrs := []resolver.Address{ + testAddrWithAttrs(testAddressStrs[0][0], nil, "priority-2-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[0][1], nil, "priority-2-0", &testLocalityIDs[0]), + testAddrWithAttrs(testAddressStrs[1][0], nil, "priority-2-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[1][1], nil, "priority-2-0", &testLocalityIDs[1]), + testAddrWithAttrs(testAddressStrs[2][0], nil, "priority-2-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[2][1], nil, "priority-2-1", &testLocalityIDs[2]), + testAddrWithAttrs(testAddressStrs[3][0], nil, "priority-2-1", &testLocalityIDs[3]), + testAddrWithAttrs(testAddressStrs[3][1], nil, "priority-2-1", &testLocalityIDs[3]), + } + + if diff := cmp.Diff(gotNames, wantNames); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotConfigs, wantConfigs); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(gotAddrs, wantAddrs, addrCmpOpts); diff != "" { + t.Errorf("buildClusterImplConfigForEDS() diff (-got +want) %v", diff) + } + +} + +func TestGroupLocalitiesByPriority(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + wantPriorities []string + wantLocalities map[string][]xdsclient.Locality + }{ + { + name: "1 locality 1 priority", + localities: []xdsclient.Locality{testLocalitiesP0[0]}, + wantPriorities: []string{"0"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0]}, + }, + }, + { + name: "2 locality 1 priority", + localities: []xdsclient.Locality{testLocalitiesP0[0], testLocalitiesP0[1]}, + wantPriorities: []string{"0"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0], testLocalitiesP0[1]}, + }, + }, + { + name: "1 locality in each", + localities: []xdsclient.Locality{testLocalitiesP0[0], testLocalitiesP1[0]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0]}, + "1": {testLocalitiesP1[0]}, + }, + }, + { + name: "2 localities in each sorted", + localities: []xdsclient.Locality{ + testLocalitiesP0[0], testLocalitiesP0[1], + testLocalitiesP1[0], testLocalitiesP1[1]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[0], testLocalitiesP0[1]}, + "1": {testLocalitiesP1[0], testLocalitiesP1[1]}, + }, + }, + { + // The localities are given in order [p1, p0, p1, 
p0], but the + // returned priority list must be sorted [p0, p1], because the list + // order is the priority order. + name: "2 localities in each needs to sort", + localities: []xdsclient.Locality{ + testLocalitiesP1[1], testLocalitiesP0[1], + testLocalitiesP1[0], testLocalitiesP0[0]}, + wantPriorities: []string{"0", "1"}, + wantLocalities: map[string][]xdsclient.Locality{ + "0": {testLocalitiesP0[1], testLocalitiesP0[0]}, + "1": {testLocalitiesP1[1], testLocalitiesP1[0]}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotPriorities, gotLocalities := groupLocalitiesByPriority(tt.localities) + if diff := cmp.Diff(gotPriorities, tt.wantPriorities); diff != "" { + t.Errorf("groupLocalitiesByPriority() diff(-got +want) %v", diff) + } + if diff := cmp.Diff(gotLocalities, tt.wantLocalities); diff != "" { + t.Errorf("groupLocalitiesByPriority() diff(-got +want) %v", diff) + } + }) + } +} + +func TestDedupSortedIntSlice(t *testing.T) { + tests := []struct { + name string + a []int + want []int + }{ + { + name: "empty", + a: []int{}, + want: []int{}, + }, + { + name: "no dup", + a: []int{0, 1, 2, 3}, + want: []int{0, 1, 2, 3}, + }, + { + name: "with dup", + a: []int{0, 0, 1, 1, 1, 2, 3}, + want: []int{0, 1, 2, 3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := dedupSortedIntSlice(tt.a); !cmp.Equal(got, tt.want) { + t.Errorf("dedupSortedIntSlice() = %v, want %v, diff %v", got, tt.want, cmp.Diff(got, tt.want)) + } + }) + } +} + +func TestLocalitiesToWeightedTarget(t *testing.T) { + tests := []struct { + name string + localities []xdsclient.Locality + priorityName string + childPolicy *internalserviceconfig.BalancerConfig + lrsServer *string + cluster string + edsService string + wantConfig *weightedtarget.LBConfig + wantAddrs []resolver.Address + }{ + { + name: "roundrobin as child, with LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + lrsServer: newString("test-lrs-server"), + cluster: "test-cluster", + edsService: "test-eds-service", + wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: 
"roundrobin as child, no LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: roundrobin.Name}, + // lrsServer is nil, so LRS policy will not be used. + wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", nil, "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + { + name: "weighted round robin as child, no LRS", + localities: []xdsclient.Locality{ + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-1-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-1-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-1"}, + Weight: 20, + }, + { + Endpoints: []xdsclient.Endpoint{ + {Address: "addr-2-1", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 90}, + {Address: "addr-2-2", HealthStatus: xdsclient.EndpointHealthStatusHealthy, Weight: 10}, + }, + ID: internal.LocalityID{Zone: "test-zone-2"}, + Weight: 80, + }, + }, + priorityName: "test-priority", + childPolicy: &internalserviceconfig.BalancerConfig{Name: weightedroundrobin.Name}, + // lrsServer is nil, so LRS policy will not be used. 
+ wantConfig: &weightedtarget.LBConfig{ + Targets: map[string]weightedtarget.Target{ + assertString(internal.LocalityID{Zone: "test-zone-1"}.ToString): { + Weight: 20, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedroundrobin.Name, + }, + }, + assertString(internal.LocalityID{Zone: "test-zone-2"}.ToString): { + Weight: 80, + ChildPolicy: &internalserviceconfig.BalancerConfig{ + Name: weightedroundrobin.Name, + }, + }, + }, + }, + wantAddrs: []resolver.Address{ + testAddrWithAttrs("addr-1-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-1-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-1"}), + testAddrWithAttrs("addr-2-1", newUint32(90), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + testAddrWithAttrs("addr-2-2", newUint32(10), "test-priority", &internal.LocalityID{Zone: "test-zone-2"}), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := localitiesToWeightedTarget(tt.localities, tt.priorityName, tt.childPolicy, tt.cluster, tt.edsService) + if diff := cmp.Diff(got, tt.wantConfig); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + if diff := cmp.Diff(got1, tt.wantAddrs, cmp.AllowUnexported(attributes.Attributes{})); diff != "" { + t.Errorf("localitiesToWeightedTarget() diff (-got +want) %v", diff) + } + }) + } +} + +func assertString(f func() (string, error)) string { + s, err := f() + if err != nil { + panic(err.Error()) + } + return s +} + +func testAddrWithAttrs(addrStr string, weight *uint32, priority string, lID *internal.LocalityID) resolver.Address { + addr := resolver.Address{Addr: addrStr} + if weight != nil { + addr = weightedroundrobin.SetAddrInfo(addr, weightedroundrobin.AddrInfo{Weight: *weight}) + } + path := []string{priority} + if lID != nil { + path = append(path, assertString(lID.ToString)) + addr = internal.SetLocalityID(addr, *lID) + } + addr = hierarchy.Set(addr, path) + return addr +} diff --git a/xds/internal/balancer/clusterresolver/eds_impl_test.go b/xds/internal/balancer/clusterresolver/eds_impl_test.go new file mode 100644 index 000000000000..f565c5249870 --- /dev/null +++ b/xds/internal/balancer/clusterresolver/eds_impl_test.go @@ -0,0 +1,743 @@ +// +build go1.12 + +/* + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package clusterresolver
+
+import (
+ "context"
+ "fmt"
+ "sort"
+ "testing"
+ "time"
+
+ corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core"
+ "github.com/google/go-cmp/cmp"
+ "google.golang.org/grpc/balancer"
+ "google.golang.org/grpc/balancer/roundrobin"
+ "google.golang.org/grpc/connectivity"
+ "google.golang.org/grpc/internal/balancer/stub"
+ internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
+ "google.golang.org/grpc/resolver"
+ "google.golang.org/grpc/xds/internal/balancer/balancergroup"
+ "google.golang.org/grpc/xds/internal/balancer/clusterimpl"
+ "google.golang.org/grpc/xds/internal/balancer/priority"
+ "google.golang.org/grpc/xds/internal/balancer/weightedtarget"
+ "google.golang.org/grpc/xds/internal/testutils"
+ "google.golang.org/grpc/xds/internal/testutils/fakeclient"
+ "google.golang.org/grpc/xds/internal/xdsclient"
+)
+
+var (
+ testClusterNames = []string{"test-cluster-1", "test-cluster-2"}
+ testSubZones = []string{"I", "II", "III", "IV"}
+ testEndpointAddrs []string
+)
+
+const testBackendAddrsCount = 12
+
+func init() {
+ for i := 0; i < testBackendAddrsCount; i++ {
+ testEndpointAddrs = append(testEndpointAddrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i))
+ }
+ clusterimpl.NewRandomWRR = testutils.NewTestWRR
+ weightedtarget.NewRandomWRR = testutils.NewTestWRR
+ balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond * 100
+}
+
+func setupTestEDS(t *testing.T, initChild *internalserviceconfig.BalancerConfig) (balancer.Balancer, *testutils.TestClientConn, *fakeclient.Client, func()) {
+ xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar)
+ cc := testutils.NewTestClientConn(t)
+ builder := balancer.Get(Name)
+ edsb := builder.Build(cc, balancer.BuildOptions{Target: resolver.Target{Endpoint: testEDSServcie}})
+ if edsb == nil {
+ t.Fatalf("builder.Build(%s) failed and returned nil", Name)
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
+ defer cancel()
+ if err := edsb.UpdateClientConnState(balancer.ClientConnState{
+ ResolverState: xdsclient.SetClient(resolver.State{}, xdsC),
+ BalancerConfig: &LBConfig{
+ DiscoveryMechanisms: []DiscoveryMechanism{{
+ Cluster: testClusterName,
+ Type: DiscoveryMechanismTypeEDS,
+ }},
+ EndpointPickingPolicy: initChild,
+ },
+ }); err != nil {
+ edsb.Close()
+ xdsC.Close()
+ t.Fatal(err)
+ }
+ if _, err := xdsC.WaitForWatchEDS(ctx); err != nil {
+ edsb.Close()
+ xdsC.Close()
+ t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err)
+ }
+ return edsb, cc, xdsC, func() {
+ edsb.Close()
+ xdsC.Close()
+ }
+}
+
+// One locality:
+// - add backend
+// - remove backend
+// - replace backend
+// - change drop rate
+func (s) TestEDS_OneLocality(t *testing.T) {
+ edsb, cc, xdsC, cleanup := setupTestEDS(t, nil)
+ defer cleanup()
+
+ // One locality with one backend.
+ clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
+
+ sc1 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Pick with only the first backend.
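+ // (With a single READY subconn, the round robin picker returns sc1 for
+ // every pick.)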
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil {
+ t.Fatal(err)
+ }
+
+ // The same locality, add one more backend.
+ clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:2], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil)
+
+ sc2 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Test roundrobin with two subconns.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2}); err != nil {
+ t.Fatal(err)
+ }
+
+ // The same locality, delete first backend.
+ clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[1:2], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil)
+
+ scToRemove := <-cc.RemoveSubConnCh
+ if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) {
+ t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove)
+ }
+ edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown})
+
+ // Test pick with only the second subconn.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil {
+ t.Fatal(err)
+ }
+
+ // The same locality, replace backend.
+ clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab4.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab4.Build()), nil)
+
+ sc3 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+ scToRemove = <-cc.RemoveSubConnCh
+ if !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) {
+ t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove)
+ }
+ edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown})
+
+ // Test pick with only the third subconn.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3}); err != nil {
+ t.Fatal(err)
+ }
+
+ // The same locality, different drop rate, dropping 50%.
+ clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{"test-drop": 50})
+ clab5.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab5.Build()), nil)
+
+ // Picks with drops.
+ if err := testPickerFromCh(cc.NewPickerCh, func(p balancer.Picker) error {
+ for i := 0; i < 100; i++ {
+ _, err := p.Pick(balancer.PickInfo{})
+ // TODO: the dropping algorithm needs a design. When the dropping algorithm
+ // is fixed, this test also needs to be fixed.
+ if i%2 == 0 && err == nil {
+ return fmt.Errorf("%d - the even number picks should be drops, got no error", i)
+ } else if i%2 != 0 && err != nil {
+ return fmt.Errorf("%d - the odd number picks should be non-drops, got error %v", i, err)
+ }
+ }
+ return nil
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // The same locality, remove drops.
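+ // Passing nil for the drops map when building the assignment clears the
+ // drop config, so the next picker should stop dropping.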
+ clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab6.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab6.Build()), nil)
+
+ // Pick without drops.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Two localities
+// - start with two localities
+// - add locality
+// - remove locality
+// - address change for the locality
+// - update locality weight
+func (s) TestEDS_TwoLocalities(t *testing.T) {
+ edsb, cc, xdsC, cleanup := setupTestEDS(t, nil)
+ defer cleanup()
+
+ // Two localities, each with one backend.
+ clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
+ sc1 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Add the second locality later to make sure sc2 belongs to the second
+ // locality. Otherwise the test is flaky because a map is used in EDS to
+ // keep localities.
+ clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
+ sc2 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Test roundrobin with two subconns.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Add another locality, with one backend.
+ clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
+ clab2.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil)
+ clab2.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil)
+
+ sc3 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Test roundrobin with three subconns.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2, sc3}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Remove first locality.
+ clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab3.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil)
+ clab3.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil)
+
+ scToRemove := <-cc.RemoveSubConnCh
+ if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) {
+ t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove)
+ }
+ edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown})
+
+ // Test pick with two subconns (without the first one).
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc3}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Add a backend to the last locality.
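+ // testEndpointAddrs[2:4] gives the last locality two backends, so one new
+ // SubConn (sc4) is expected below.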
+ clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab4.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil)
+ clab4.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab4.Build()), nil)
+
+ sc4 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Test picks with the three remaining subconns.
+ //
+ // Locality-1 will be picked twice, and locality-2 will be picked twice.
+ // Locality-1 contains only sc2, locality-2 contains sc3 and sc4. So expect
+ // two sc2's and sc3, sc4.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc2, sc3, sc4}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Change weight of the locality[1].
+ clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab5.AddLocality(testSubZones[1], 2, 0, testEndpointAddrs[1:2], nil)
+ clab5.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab5.Build()), nil)
+
+ // Test picks with different locality weights.
+ //
+ // Locality-1 will be picked four times, and locality-2 will be picked twice
+ // (weight 2 and 1). Locality-1 contains only sc2, locality-2 contains sc3 and
+ // sc4. So expect four sc2's and sc3, sc4.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2, sc2, sc2, sc2, sc3, sc4}); err != nil {
+ t.Fatal(err)
+ }
+
+ // Change weight of the locality[1] to 0, it should never be picked.
+ clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab6.AddLocality(testSubZones[1], 0, 0, testEndpointAddrs[1:2], nil)
+ clab6.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab6.Build()), nil)
+
+ // Changing weight of locality[1] to 0 caused it to be removed. Its SubConn
+ // should also be removed.
+ //
+ // NOTE: this is because we handle a locality with weight 0 the same as if
+ // the locality doesn't exist. If this changes in the future, this
+ // removeSubConn behavior will also change.
+ scToRemove2 := <-cc.RemoveSubConnCh
+ if !cmp.Equal(scToRemove2, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) {
+ t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove2)
+ }
+
+ // Test picks after locality-1 is removed.
+ //
+ // Locality-1 will not be picked, and locality-2 will be picked.
+ // Locality-2 contains sc3 and sc4. So expect sc3, sc4.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3, sc4}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// The EDS balancer gets EDS resp with unhealthy endpoints. Test that only
+// healthy ones are used.
+func (s) TestEDS_EndpointsHealth(t *testing.T) {
+ edsb, cc, xdsC, cleanup := setupTestEDS(t, nil)
+ defer cleanup()
+
+ // Two localities, each with six backends, one in each health state.
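+ // Going by the expectations below, only the HEALTHY and UNKNOWN endpoints
+ // should be used, i.e. two usable backends per locality.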
+ clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:6], &testutils.AddLocalityOptions{
+ Health: []corepb.HealthStatus{
+ corepb.HealthStatus_HEALTHY,
+ corepb.HealthStatus_UNHEALTHY,
+ corepb.HealthStatus_UNKNOWN,
+ corepb.HealthStatus_DRAINING,
+ corepb.HealthStatus_TIMEOUT,
+ corepb.HealthStatus_DEGRADED,
+ },
+ })
+ clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[6:12], &testutils.AddLocalityOptions{
+ Health: []corepb.HealthStatus{
+ corepb.HealthStatus_HEALTHY,
+ corepb.HealthStatus_UNHEALTHY,
+ corepb.HealthStatus_UNKNOWN,
+ corepb.HealthStatus_DRAINING,
+ corepb.HealthStatus_TIMEOUT,
+ corepb.HealthStatus_DEGRADED,
+ },
+ })
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
+
+ var (
+ readySCs []balancer.SubConn
+ newSubConnAddrStrs []string
+ )
+ for i := 0; i < 4; i++ {
+ addr := <-cc.NewSubConnAddrsCh
+ newSubConnAddrStrs = append(newSubConnAddrStrs, addr[0].Addr)
+ sc := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+ readySCs = append(readySCs, sc)
+ }
+
+ wantNewSubConnAddrStrs := []string{
+ testEndpointAddrs[0],
+ testEndpointAddrs[2],
+ testEndpointAddrs[6],
+ testEndpointAddrs[8],
+ }
+ sortStrTrans := cmp.Transformer("Sort", func(in []string) []string {
+ out := append([]string(nil), in...) // Copy input to avoid mutating it.
+ sort.Strings(out)
+ return out
+ })
+ if !cmp.Equal(newSubConnAddrStrs, wantNewSubConnAddrStrs, sortStrTrans) {
+ t.Fatalf("want newSubConn with address %v, got %v", wantNewSubConnAddrStrs, newSubConnAddrStrs)
+ }
+
+ // There should be exactly 4 new SubConns. Check to make sure no more
+ // SubConns are created.
+ select {
+ case <-cc.NewSubConnCh:
+ t.Fatalf("Got unexpected new subconn")
+ case <-time.After(time.Microsecond * 100):
+ }
+
+ // Test roundrobin with the subconns.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, readySCs); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// TestEDS_EmptyUpdate covers the case where the EDS impl receives an empty
+// update.
+//
+// It should send an error picker with transient failure to the parent.
+func (s) TestEDS_EmptyUpdate(t *testing.T) {
+ edsb, cc, xdsC, cleanup := setupTestEDS(t, nil)
+ defer cleanup()
+
+ const cacheTimeout = 100 * time.Microsecond
+ oldCacheTimeout := balancergroup.DefaultSubBalancerCloseTimeout
+ balancergroup.DefaultSubBalancerCloseTimeout = cacheTimeout
+ defer func() { balancergroup.DefaultSubBalancerCloseTimeout = oldCacheTimeout }()
+
+ // The first update is an empty update.
+ xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil)
+ // Pick should fail with a transient failure, all-priorities-removed error.
+ if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil {
+ t.Fatal(err)
+ }
+
+ // One locality with one backend.
+ clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
+
+ sc1 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Pick with only the first backend.
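+ // This verifies the balancer recovers once a non-empty update follows the
+ // initial empty one.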
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil {
+ t.Fatal(err)
+ }
+
+ xdsC.InvokeWatchEDSCallback("", xdsclient.EndpointsUpdate{}, nil)
+ // Pick should fail with a transient failure, all-priorities-removed error.
+ if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil {
+ t.Fatal(err)
+ }
+
+ // Wait for the old SubConn to be removed (which happens when the child
+ // policy is closed), so a new update would trigger a new SubConn (we need
+ // this new SubConn to tell if the next picker is newly created).
+ scToRemove := <-cc.RemoveSubConnCh
+ if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) {
+ t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove)
+ }
+ edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown})
+
+ // Handle another update with priorities and localities.
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
+
+ sc2 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Pick with only the first backend.
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Create an xDS balancer and update the sub-balancer before handling EDS
+// responses. Then switch between round-robin and a test stub balancer after
+// handling the first EDS response.
+func (s) TestEDS_UpdateSubBalancerName(t *testing.T) {
+ const balancerName = "stubBalancer-TestEDS_UpdateSubBalancerName"
+
+ stub.Register(balancerName, stub.BalancerFuncs{
+ UpdateClientConnState: func(bd *stub.BalancerData, s balancer.ClientConnState) error {
+ m, _ := bd.Data.(map[string]bool)
+ if m == nil {
+ m = make(map[string]bool)
+ bd.Data = m
+ }
+ for _, addr := range s.ResolverState.Addresses {
+ if !m[addr.Addr] {
+ m[addr.Addr] = true
+ bd.ClientConn.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{})
+ }
+ }
+ return nil
+ },
+ UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) {
+ bd.ClientConn.UpdateState(balancer.State{
+ ConnectivityState: state.ConnectivityState,
+ Picker: &testutils.TestConstPicker{Err: testutils.ErrTestConstPicker},
+ })
+ },
+ })
+
+ t.Logf("initialize with sub-balancer: stub-balancer")
+ edsb, cc, xdsC, cleanup := setupTestEDS(t, &internalserviceconfig.BalancerConfig{Name: balancerName})
+ defer cleanup()
+
+ t.Logf("update sub-balancer to stub-balancer")
+ if err := edsb.UpdateClientConnState(balancer.ClientConnState{
+ BalancerConfig: &LBConfig{
+ DiscoveryMechanisms: []DiscoveryMechanism{{
+ Cluster: testClusterName,
+ Type: DiscoveryMechanismTypeEDS,
+ }},
+ EndpointPickingPolicy: &internalserviceconfig.BalancerConfig{
+ Name: balancerName,
+ },
+ },
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // Two localities, each with one backend.
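+ // The stub balancer registered above creates one SubConn per address and
+ // always reports testutils.ErrTestConstPicker from its picker, which is
+ // what the error-picker assertions below rely on.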
+ clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) + clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) + clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) + + for i := 0; i < 2; i++ { + sc := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + } + + if err := testErrPickerFromCh(cc.NewPickerCh, testutils.ErrTestConstPicker); err != nil { + t.Fatal(err) + } + + t.Logf("update sub-balancer to round-robin") + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + }}, + EndpointPickingPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + <-cc.RemoveSubConnCh + } + + sc1 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + sc2 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1, sc2}); err != nil { + t.Fatal(err) + } + + t.Logf("update sub-balancer to stub-balancer") + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + }}, + EndpointPickingPolicy: &internalserviceconfig.BalancerConfig{ + Name: balancerName, + }, + }, + }); err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + scToRemove := <-cc.RemoveSubConnCh + if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) && + !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { + t.Fatalf("RemoveSubConn, want (%v or %v), got %v", sc1, sc2, scToRemove) + } + edsb.UpdateSubConnState(scToRemove, balancer.SubConnState{ConnectivityState: connectivity.Shutdown}) + } + + for i := 0; i < 2; i++ { + sc := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + } + + if err := testErrPickerFromCh(cc.NewPickerCh, testutils.ErrTestConstPicker); err != nil { + t.Fatal(err) + } + + t.Logf("update sub-balancer to round-robin") + if err := edsb.UpdateClientConnState(balancer.ClientConnState{ + BalancerConfig: &LBConfig{ + DiscoveryMechanisms: []DiscoveryMechanism{{ + Cluster: testClusterName, + Type: DiscoveryMechanismTypeEDS, + }}, + EndpointPickingPolicy: &internalserviceconfig.BalancerConfig{ + Name: roundrobin.Name, + }, + }, + }); err != nil { + t.Fatal(err) + } + + for i := 0; i < 2; i++ { + <-cc.RemoveSubConnCh + } + + sc3 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + sc4 := <-cc.NewSubConnCh + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready}) + + if err := 
testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3, sc4}); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (s) TestEDS_CircuitBreaking(t *testing.T) {
+ edsb, cc, xdsC, cleanup := setupTestEDS(t, nil)
+ defer cleanup()
+
+ var maxRequests uint32 = 50
+ if err := edsb.UpdateClientConnState(balancer.ClientConnState{
+ BalancerConfig: &LBConfig{
+ DiscoveryMechanisms: []DiscoveryMechanism{{
+ Cluster: testClusterName,
+ MaxConcurrentRequests: &maxRequests,
+ Type: DiscoveryMechanismTypeEDS,
+ }},
+ EndpointPickingPolicy: &internalserviceconfig.BalancerConfig{
+ Name: roundrobin.Name,
+ },
+ },
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // One locality with one backend.
+ clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
+ clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
+ sc1 := <-cc.NewSubConnCh
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+
+ // Picks with drops.
+ dones := []func(){}
+ p := <-cc.NewPickerCh
+ for i := 0; i < 100; i++ {
+ pr, err := p.Pick(balancer.PickInfo{})
+ if i < 50 && err != nil {
+ t.Errorf("The first 50%% picks should be non-drops, got error %v", err)
+ } else if i > 50 && err == nil {
+ t.Errorf("The second 50%% picks should be drops, got no error")
+ }
+ dones = append(dones, func() {
+ if pr.Done != nil {
+ pr.Done(balancer.DoneInfo{})
+ }
+ })
+ }
+
+ for _, done := range dones {
+ done()
+ }
+ dones = []func(){}
+
+ // Pick without drops.
+ for i := 0; i < 50; i++ {
+ pr, err := p.Pick(balancer.PickInfo{})
+ if err != nil {
+ t.Errorf("The third 50%% picks should be non-drops, got error %v", err)
+ }
+ dones = append(dones, func() {
+ if pr.Done != nil {
+ pr.Done(balancer.DoneInfo{})
+ }
+ })
+ }
+
+ // Without this, future tests with the same service name will fail.
+ for _, done := range dones {
+ done()
+ }
+
+ // Send another update, with only circuit breaking update (and no picker
+ // update afterwards). Make sure the new picker uses the new configs.
+ var maxRequests2 uint32 = 10
+ if err := edsb.UpdateClientConnState(balancer.ClientConnState{
+ BalancerConfig: &LBConfig{
+ DiscoveryMechanisms: []DiscoveryMechanism{{
+ Cluster: testClusterName,
+ MaxConcurrentRequests: &maxRequests2,
+ Type: DiscoveryMechanismTypeEDS,
+ }},
+ EndpointPickingPolicy: &internalserviceconfig.BalancerConfig{
+ Name: roundrobin.Name,
+ },
+ },
+ }); err != nil {
+ t.Fatal(err)
+ }
+
+ // Picks with drops.
+ dones = []func(){}
+ p2 := <-cc.NewPickerCh
+ for i := 0; i < 100; i++ {
+ pr, err := p2.Pick(balancer.PickInfo{})
+ if i < 10 && err != nil {
+ t.Errorf("The first 10%% picks should be non-drops, got error %v", err)
+ } else if i > 10 && err == nil {
+ t.Errorf("The next 90%% picks should be drops, got no error")
+ }
+ dones = append(dones, func() {
+ if pr.Done != nil {
+ pr.Done(balancer.DoneInfo{})
+ }
+ })
+ }
+
+ for _, done := range dones {
+ done()
+ }
+ dones = []func(){}
+
+ // Pick without drops.
+ for i := 0; i < 10; i++ {
+ pr, err := p2.Pick(balancer.PickInfo{})
+ if err != nil {
+ t.Errorf("The next 10%% picks should be non-drops, got error %v", err)
+ }
+ dones = append(dones, func() {
+ if pr.Done != nil {
+ pr.Done(balancer.DoneInfo{})
+ }
+ })
+ }
+
+ // Without this, future tests with the same service name will fail.
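+ // Calling Done() presumably releases the per-service in-flight request
+ // slots counted against MaxConcurrentRequests.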
+ for _, done := range dones { + done() + } +} diff --git a/xds/internal/balancer/edsbalancer/logging.go b/xds/internal/balancer/clusterresolver/logging.go similarity index 84% rename from xds/internal/balancer/edsbalancer/logging.go rename to xds/internal/balancer/clusterresolver/logging.go index be4d0a512d16..728f1f709c28 100644 --- a/xds/internal/balancer/edsbalancer/logging.go +++ b/xds/internal/balancer/clusterresolver/logging.go @@ -16,7 +16,7 @@ * */ -package edsbalancer +package clusterresolver import ( "fmt" @@ -25,10 +25,10 @@ import ( internalgrpclog "google.golang.org/grpc/internal/grpclog" ) -const prefix = "[eds-lb %p] " +const prefix = "[xds-cluster-resolver-lb %p] " var logger = grpclog.Component("xds") -func prefixLogger(p *edsBalancer) *internalgrpclog.PrefixLogger { +func prefixLogger(p *clusterResolverBalancer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) } diff --git a/xds/internal/balancer/edsbalancer/eds_impl_priority_test.go b/xds/internal/balancer/clusterresolver/priority_test.go similarity index 54% rename from xds/internal/balancer/edsbalancer/eds_impl_priority_test.go rename to xds/internal/balancer/clusterresolver/priority_test.go index 51b35f22f09d..2e0e9fd5d2ff 100644 --- a/xds/internal/balancer/edsbalancer/eds_impl_priority_test.go +++ b/xds/internal/balancer/clusterresolver/priority_test.go @@ -17,7 +17,7 @@ * limitations under the License. */ -package edsbalancer +package clusterresolver import ( "context" @@ -28,6 +28,8 @@ import ( "github.com/google/go-cmp/cmp" "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/balancer/priority" "google.golang.org/grpc/xds/internal/testutils" ) @@ -36,15 +38,14 @@ import ( // // Init 0 and 1; 0 is up, use 0; add 2, use 0; remove 2, use 0. func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with priorities [0, 1], each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[0]; got != want { @@ -53,22 +54,20 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { sc1 := <-cc.NewSubConnCh // p0 is ready. - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p1 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } - // Add p2, it shouldn't cause any udpates. 
+ // Add p2, it shouldn't cause any updates. clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) select { case <-cc.NewPickerCh: @@ -84,7 +83,7 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab3.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) select { case <-cc.NewPickerCh: @@ -102,15 +101,14 @@ func (s) TestEDSPriority_HighPriorityReady(t *testing.T) { // Init 0 and 1; 0 is up, use 0; 0 is down, 1 is up, use 1; add 2, use 1; 1 is // down, use 2; remove 2, use 1. func (s) TestEDSPriority_SwitchPriority(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with priorities [0, 1], each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { @@ -119,41 +117,35 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { sc0 := <-cc.NewSubConnCh // p0 is ready. - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p0 := <-cc.NewPickerCh - want := []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Turn down 0, 1 is used. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 1. 
- p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } - // Add p2, it shouldn't cause any udpates. + // Add p2, it shouldn't cause any updates. clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil) select { case <-cc.NewPickerCh: @@ -166,29 +158,25 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { } // Turn down 1, use 2 - edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 2. - p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil { + t.Fatal(err) } // Remove 2, use 1. clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab3.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil) // p2 SubConns are removed. scToRemove := <-cc.RemoveSubConnCh @@ -197,28 +185,23 @@ func (s) TestEDSPriority_SwitchPriority(t *testing.T) { } // Should get an update with 1's old picker, to override 2's old picker. - p3 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - if _, err := p3.Pick(balancer.PickInfo{}); err != balancer.ErrTransientFailure { - t.Fatalf("want pick error %v, got %v", balancer.ErrTransientFailure, err) - } + if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrTransientFailure); err != nil { + t.Fatal(err) } + } // Add a lower priority while the higher priority is down. // // Init 0 and 1; 0 and 1 both down; add 2, use 2. func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. 
 clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
 clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
 clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
-
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
 addrs0 := <-cc.NewSubConnAddrsCh
 if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
@@ -226,21 +209,18 @@ func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) {
 sc0 := <-cc.NewSubConnCh
 
 // Turn down 0, 1 is used.
- edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure)
+ edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
 addrs1 := <-cc.NewSubConnAddrsCh
 if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
 }
 sc1 := <-cc.NewSubConnCh
 // Turn down 1, pick should error.
- edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure)
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
 
 // Test pick failure.
- pFail := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- if _, err := pFail.Pick(balancer.PickInfo{}); err != balancer.ErrTransientFailure {
- t.Fatalf("want pick error %v, got %v", balancer.ErrTransientFailure, err)
- }
+ if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrTransientFailure); err != nil {
+ t.Fatal(err)
 }
 
 // Add p2, it should create a new SubConn.
@@ -248,41 +228,34 @@ func (s) TestEDSPriority_HigherDownWhileAddingLower(t *testing.T) {
 clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
 clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil)
 clab2.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build()))
-
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil)
 addrs2 := <-cc.NewSubConnAddrsCh
 if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
 }
 sc2 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc2, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc2, connectivity.Ready)
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready})
 
 // Test pick with 2.
- p2 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- gotSCSt, _ := p2.Pick(balancer.PickInfo{})
- if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2)
- }
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil {
+ t.Fatal(err)
 }
+
 }
 
 // When a higher priority becomes available, all lower priorities are closed.
 //
 // Init 0,1,2; 0 and 1 down, use 2; 0 up, close 1 and 2.
 func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) {
- cc := testutils.NewTestClientConn(t)
- edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil)
- edsb.enqueueChildBalancerStateUpdate = edsb.updateState
-
+ edsb, cc, xdsC, cleanup := setupTestEDS(t, nil)
+ defer cleanup()
 // Three localities, with priorities [0,1,2], each with one backend.
 clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
 clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
 clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil)
 clab1.AddLocality(testSubZones[2], 1, 2, testEndpointAddrs[2:3], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
-
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
 addrs0 := <-cc.NewSubConnAddrsCh
 if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
@@ -290,39 +263,55 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) {
 sc0 := <-cc.NewSubConnCh
 
 // Turn down 0, 1 is used.
- edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure)
+ edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
 addrs1 := <-cc.NewSubConnAddrsCh
 if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
 }
 sc1 := <-cc.NewSubConnCh
 // Turn down 1, 2 is used.
- edsb.handleSubConnStateChange(sc1, connectivity.TransientFailure)
+ edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure})
 addrs2 := <-cc.NewSubConnAddrsCh
 if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
 }
 sc2 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc2, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc2, connectivity.Ready)
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready})
 
 // Test pick with 2.
- p2 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- gotSCSt, _ := p2.Pick(balancer.PickInfo{})
- if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2)
- }
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc2}); err != nil {
+ t.Fatal(err)
 }
 
 // When 0 becomes ready, 0 should be used, 1 and 2 should all be closed.
- edsb.handleSubConnStateChange(sc0, connectivity.Ready)
+ edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready})
+ var (
+ scToRemove []balancer.SubConn
+ scToRemoveMap = make(map[balancer.SubConn]struct{})
+ )
+ // Each subconn is removed twice. This is OK in production, but it makes
+ // testing harder.
+ //
+ // The sub-balancer to be closed is priority's child, clusterimpl, whose
+ // child is weightedtarget.
+ //
+ // - When clusterimpl is removed from priority's balancergroup, all its
+ // subconns are removed once.
+ // - When clusterimpl is closed, it closes weightedtarget, and this
+ // weightedtarget's balancer removes all the same subconns again.
+ for i := 0; i < 4; i++ {
+ // We expect 2 subconns, so we receive from the channel 4 times.
+ scToRemoveMap[<-cc.RemoveSubConnCh] = struct{}{}
+ }
+ for sc := range scToRemoveMap {
+ scToRemove = append(scToRemove, sc)
+ }
 
 // sc1 and sc2 should be removed.
 //
 // With localities caching, the lower priorities are closed after a timeout,
 // in goroutines. The order is no longer guaranteed.
- scToRemove := []balancer.SubConn{<-cc.RemoveSubConnCh, <-cc.RemoveSubConnCh} if !(cmp.Equal(scToRemove[0], sc1, cmp.AllowUnexported(testutils.TestSubConn{})) && cmp.Equal(scToRemove[1], sc2, cmp.AllowUnexported(testutils.TestSubConn{}))) && !(cmp.Equal(scToRemove[0], sc2, cmp.AllowUnexported(testutils.TestSubConn{})) && @@ -331,12 +320,8 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { } // Test pick with 0. - p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p0.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc0) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } } @@ -347,23 +332,20 @@ func (s) TestEDSPriority_HigherReadyCloseAllLower(t *testing.T) { func (s) TestEDSPriority_InitTimeout(t *testing.T) { const testPriorityInitTimeout = time.Second defer func() func() { - old := defaultPriorityInitTimeout - defaultPriorityInitTimeout = testPriorityInitTimeout + old := priority.DefaultPriorityInitTimeout + priority.DefaultPriorityInitTimeout = testPriorityInitTimeout return func() { - defaultPriorityInitTimeout = old + priority.DefaultPriorityInitTimeout = old } }()() - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) @@ -371,7 +353,7 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { sc0 := <-cc.NewSubConnCh // Keep 0 in connecting, 1 will be used after init timeout. - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) // Make sure new SubConn is created before timeout. select { @@ -386,16 +368,12 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test pick with 1. 
- p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } } @@ -404,51 +382,44 @@ func (s) TestEDSPriority_InitTimeout(t *testing.T) { // - start with 2 locality with p0 and p1 // - add localities to existing p0 and p1 func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - + edsb, cc, xdsC, cleanup := setupTestEDS(t, nil) + defer cleanup() // Two localities, with different priorities, each with one backend. clab0 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) clab0.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) clab0.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab0.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab0.Build()), nil) addrs0 := <-cc.NewSubConnAddrsCh if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc0 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc0, connectivity.Connecting) - edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p0 subconns. - p0 := <-cc.NewPickerCh - want := []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Turn down p0 subconns, p1 subconns will be created. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) addrs1 := <-cc.NewSubConnAddrsCh if got, want := addrs1[0].Addr, testEndpointAddrs[1]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p1 subconns. - p1 := <-cc.NewPickerCh - want = []balancer.SubConn{sc1} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc1}); err != nil { + t.Fatal(err) } // Reconnect p0 subconns, p1 subconn will be closed. 
- edsb.handleSubConnStateChange(sc0, connectivity.Ready) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready}) scToRemove := <-cc.RemoveSubConnCh if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { @@ -456,10 +427,8 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { } // Test roundrobin with only p0 subconns. - p2 := <-cc.NewPickerCh - want = []balancer.SubConn{sc0} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil { + t.Fatal(err) } // Add two localities, with two priorities, with one backend. @@ -468,39 +437,34 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) { clab1.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil) clab1.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil) clab1.AddLocality(testSubZones[3], 1, 1, testEndpointAddrs[3:4], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - + xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil) addrs2 := <-cc.NewSubConnAddrsCh if got, want := addrs2[0].Addr, testEndpointAddrs[2]; got != want { t.Fatalf("sc is created with addr %v, want %v", got, want) } sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only two p0 subconns. - p3 := <-cc.NewPickerCh - want = []balancer.SubConn{sc0, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0, sc2}); err != nil { + t.Fatal(err) } // Turn down p0 subconns, p1 subconns will be created. - edsb.handleSubConnStateChange(sc0, connectivity.TransientFailure) - edsb.handleSubConnStateChange(sc2, connectivity.TransientFailure) + edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) + edsb.UpdateSubConnState(sc2, balancer.SubConnState{ConnectivityState: connectivity.TransientFailure}) sc3 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc3, connectivity.Connecting) - edsb.handleSubConnStateChange(sc3, connectivity.Ready) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc3, balancer.SubConnState{ConnectivityState: connectivity.Ready}) sc4 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc4, connectivity.Connecting) - edsb.handleSubConnStateChange(sc4, connectivity.Ready) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) + edsb.UpdateSubConnState(sc4, balancer.SubConnState{ConnectivityState: connectivity.Ready}) // Test roundrobin with only p1 subconns. 
- p4 := <-cc.NewPickerCh
- want = []balancer.SubConn{sc3, sc4}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p4)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc3, sc4}); err != nil {
+ t.Fatal(err)
 }
 }
 
@@ -508,62 +472,55 @@ func (s) TestEDSPriority_MultipleLocalities(t *testing.T) {
 func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) {
 const testPriorityInitTimeout = time.Second
 defer func() func() {
- old := defaultPriorityInitTimeout
- defaultPriorityInitTimeout = testPriorityInitTimeout
+ old := priority.DefaultPriorityInitTimeout
+ priority.DefaultPriorityInitTimeout = testPriorityInitTimeout
 return func() {
- defaultPriorityInitTimeout = old
+ priority.DefaultPriorityInitTimeout = old
 }
 }()()
- cc := testutils.NewTestClientConn(t)
- edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil)
- edsb.enqueueChildBalancerStateUpdate = edsb.updateState
-
+ edsb, cc, xdsC, cleanup := setupTestEDS(t, nil)
+ defer cleanup()
 // Two localities, with different priorities, each with one backend.
 clab0 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
 clab0.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
 clab0.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[1:2], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab0.Build()))
-
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab0.Build()), nil)
 addrs0 := <-cc.NewSubConnAddrsCh
 if got, want := addrs0[0].Addr, testEndpointAddrs[0]; got != want {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
 }
 sc0 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc0, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc0, connectivity.Ready)
+ edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc0, balancer.SubConnState{ConnectivityState: connectivity.Ready})
 
 // Test roundrobin with only p0 subconns.
- p0 := <-cc.NewPickerCh
- want := []balancer.SubConn{sc0}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p0)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc0}); err != nil {
+ t.Fatal(err)
 }
 
 // Remove all priorities.
 clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
-
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab1.Build()), nil)
 // p0 subconn should be removed.
 scToRemove := <-cc.RemoveSubConnCh
+ <-cc.RemoveSubConnCh // Drain the duplicate SubConn removal.
 if !cmp.Equal(scToRemove, sc0, cmp.AllowUnexported(testutils.TestSubConn{})) {
 t.Fatalf("RemoveSubConn, want %v, got %v", sc0, scToRemove)
 }
 
 // Test that pick returns TransientFailure.
- pFail := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved {
- t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err)
- }
+ if err := testErrPickerFromCh(cc.NewPickerCh, priority.ErrAllPrioritiesRemoved); err != nil {
+ t.Fatal(err)
 }
 
 // Re-add two localities, with previous priorities, but different backends.
 clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
 clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil)
 clab2.AddLocality(testSubZones[1], 1, 1, testEndpointAddrs[3:4], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build()))
-
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab2.Build()), nil)
 addrs01 := <-cc.NewSubConnAddrsCh
 if got, want := addrs01[0].Addr, testEndpointAddrs[2]; got != want {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
@@ -580,45 +537,39 @@ func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) {
 t.Fatalf("sc is created with addr %v, want %v", got, want)
 }
 sc11 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc11, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc11, connectivity.Ready)
+ edsb.UpdateSubConnState(sc11, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc11, balancer.SubConnState{ConnectivityState: connectivity.Ready})
 
 // Test roundrobin with only p1 subconns.
- p1 := <-cc.NewPickerCh
- want = []balancer.SubConn{sc11}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
+ if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc11}); err != nil {
+ t.Fatal(err)
 }
 
 // Remove p1 from EDS, to fallback to p0.
 clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
 clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build()))
+ xdsC.InvokeWatchEDSCallback("", parseEDSRespProtoForTesting(clab3.Build()), nil)
 
 // p1 subconn should be removed.
 scToRemove1 := <-cc.RemoveSubConnCh
+ <-cc.RemoveSubConnCh // Drain the duplicate SubConn removal.
 if !cmp.Equal(scToRemove1, sc11, cmp.AllowUnexported(testutils.TestSubConn{})) {
 t.Fatalf("RemoveSubConn, want %v, got %v", sc11, scToRemove1)
 }
 
 // Test that pick returns ErrNoSubConnAvailable (p0's SubConn is not ready yet).
- pFail1 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- if scst, err := pFail1.Pick(balancer.PickInfo{}); err != balancer.ErrNoSubConnAvailable {
- t.Fatalf("want pick error _, %v, got %v, _ ,%v", balancer.ErrTransientFailure, scst, err)
- }
+ if err := testErrPickerFromCh(cc.NewPickerCh, balancer.ErrNoSubConnAvailable); err != nil {
+ t.Fatal(err)
 }
 
 // Send a ready update for the p0 sc that was received when re-adding
 // localities to EDS.
- edsb.handleSubConnStateChange(sc01, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc01, connectivity.Ready)
+ edsb.UpdateSubConnState(sc01, balancer.SubConnState{ConnectivityState: connectivity.Connecting})
+ edsb.UpdateSubConnState(sc01, balancer.SubConnState{ConnectivityState: connectivity.Ready})
 
 // Test roundrobin with only p0 subconns.
- p2 := <-cc.NewPickerCh - want = []balancer.SubConn{sc01} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) + if err := testRoundRobinPickerFromCh(cc.NewPickerCh, []balancer.SubConn{sc01}); err != nil { + t.Fatal(err) } select { @@ -632,83 +583,16 @@ func (s) TestEDSPriority_RemovesAllLocalities(t *testing.T) { } } -func (s) TestPriorityType(t *testing.T) { - p0 := newPriorityType(0) - p1 := newPriorityType(1) - p2 := newPriorityType(2) - - if !p0.higherThan(p1) || !p0.higherThan(p2) { - t.Errorf("want p0 to be higher than p1 and p2, got p0>p1: %v, p0>p2: %v", !p0.higherThan(p1), !p0.higherThan(p2)) - } - if !p1.lowerThan(p0) || !p1.higherThan(p2) { - t.Errorf("want p1 to be between p0 and p2, got p1p2: %v", !p1.lowerThan(p0), !p1.higherThan(p2)) - } - if !p2.lowerThan(p0) || !p2.lowerThan(p1) { - t.Errorf("want p2 to be lower than p0 and p1, got p2 subConnMu, but this is implicit via - // balancers (starting balancer with next priority while holding priorityMu, - // and the balancer may create new SubConn). - - priorityMu sync.Mutex - // priorities are pointers, and will be nil when EDS returns empty result. - priorityInUse priorityType - priorityLowest priorityType - priorityToState map[priorityType]*balancer.State - // The timer to give a priority 10 seconds to connect. And if the priority - // doesn't go into Ready/Failure, start the next priority. - // - // One timer is enough because there can be at most one priority in init - // state. - priorityInitTimer *time.Timer - - subConnMu sync.Mutex - subConnToPriority map[balancer.SubConn]priorityType - - pickerMu sync.Mutex - dropConfig []xdsclient.OverloadDropConfig - drops []*dropper - innerState balancer.State // The state of the picker without drop support. - serviceRequestsCounter *client.ServiceRequestsCounter - serviceRequestCountMax uint32 - - clusterNameMu sync.Mutex - clusterName string -} - -// newEDSBalancerImpl create a new edsBalancerImpl. -func newEDSBalancerImpl(cc balancer.ClientConn, bOpts balancer.BuildOptions, enqueueState func(priorityType, balancer.State), lr load.PerClusterReporter, logger *grpclog.PrefixLogger) *edsBalancerImpl { - edsImpl := &edsBalancerImpl{ - cc: cc, - buildOpts: bOpts, - logger: logger, - subBalancerBuilder: balancer.Get(roundrobin.Name), - loadReporter: lr, - - enqueueChildBalancerStateUpdate: enqueueState, - - priorityToLocalities: make(map[priorityType]*balancerGroupWithConfig), - priorityToState: make(map[priorityType]*balancer.State), - subConnToPriority: make(map[balancer.SubConn]priorityType), - serviceRequestCountMax: defaultServiceRequestCountMax, - } - // Don't start balancer group here. Start it when handling the first EDS - // response. Otherwise the balancer group will be started with round-robin, - // and if users specify a different sub-balancer, all balancers in balancer - // group will be closed and recreated when sub-balancer update happens. - return edsImpl -} - -// handleChildPolicy updates the child balancers handling endpoints. Child -// policy is roundrobin by default. If the specified balancer is not installed, -// the old child balancer will be used. -// -// HandleChildPolicy and HandleEDSResponse must be called by the same goroutine. 
-func (edsImpl *edsBalancerImpl) handleChildPolicy(name string, config json.RawMessage) {
- if edsImpl.subBalancerBuilder.Name() == name {
- return
- }
- newSubBalancerBuilder := balancer.Get(name)
- if newSubBalancerBuilder == nil {
- edsImpl.logger.Infof("edsBalancerImpl: failed to find balancer with name %q, keep using %q", name, edsImpl.subBalancerBuilder.Name())
- return
- }
- edsImpl.subBalancerBuilder = newSubBalancerBuilder
- for _, bgwc := range edsImpl.priorityToLocalities {
- if bgwc == nil {
- continue
- }
- for lid, config := range bgwc.configs {
- lidJSON, err := lid.ToString()
- if err != nil {
- edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid)
- continue
- }
- // TODO: (eds) add support to balancer group to support smoothly
- // switching sub-balancers (keep old balancer around until new
- // balancer becomes ready).
- bgwc.bg.Remove(lidJSON)
- bgwc.bg.Add(lidJSON, edsImpl.subBalancerBuilder)
- bgwc.bg.UpdateClientConnState(lidJSON, balancer.ClientConnState{
- ResolverState: resolver.State{Addresses: config.addrs},
- })
- // There's no need to manually update the picker here, because the
- // new sub-balancer will send its picker later.
- }
- }
-}
-
-// updateDrops compares new drop policies with the old. If they are different,
-// it updates the drop policies and sends ClientConn an updated picker.
-func (edsImpl *edsBalancerImpl) updateDrops(dropConfig []xdsclient.OverloadDropConfig) {
- if cmp.Equal(dropConfig, edsImpl.dropConfig) {
- return
- }
- edsImpl.pickerMu.Lock()
- edsImpl.dropConfig = dropConfig
- var newDrops []*dropper
- for _, c := range edsImpl.dropConfig {
- newDrops = append(newDrops, newDropper(c))
- }
- edsImpl.drops = newDrops
- if edsImpl.innerState.Picker != nil {
- // Update picker with old inner picker, new drops.
- edsImpl.cc.UpdateState(balancer.State{
- ConnectivityState: edsImpl.innerState.ConnectivityState,
- Picker: newDropPicker(edsImpl.innerState.Picker, newDrops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)},
- )
- }
- edsImpl.pickerMu.Unlock()
-}
-
-// handleEDSResponse handles the EDS response and creates/deletes localities and
-// SubConns. It also handles drops.
-//
-// HandleChildPolicy and HandleEDSResponse must be called by the same goroutine.
-func (edsImpl *edsBalancerImpl) handleEDSResponse(edsResp xdsclient.EndpointsUpdate) {
- // TODO: Unhandled fields from EDS response:
- // - edsResp.GetPolicy().GetOverprovisioningFactor()
- // - locality.GetPriority()
- // - lbEndpoint.GetMetadata(): contains BNS name, send to sub-balancers
- //   - as service config or as resolved address
- // - if socketAddress is not ip:port
- //   - socketAddress.GetNamedPort(), socketAddress.GetResolverName()
- //   - resolve endpoint's name with another resolver
-
- // If the first EDS update is an empty update, nothing is changing from the
- // previous update (which is the default empty value). We need to explicitly
- // handle the first update being empty, and send a transient failure picker.
- //
- // TODO: define Equal() on type EndpointsUpdate to avoid DeepEqual. And do
- // the same for the other types.
- if !edsImpl.respReceived && reflect.DeepEqual(edsResp, xdsclient.EndpointsUpdate{}) {
- edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.TransientFailure, Picker: base.NewErrPicker(errAllPrioritiesRemoved)})
- }
- edsImpl.respReceived = true
-
- edsImpl.updateDrops(edsResp.Drops)
-
- // Filter out all localities with weight 0.
- // - // Locality weighted load balancer can be enabled by setting an option in - // CDS, and the weight of each locality. Currently, without the guarantee - // that CDS is always sent, we assume locality weighted load balance is - // always enabled, and ignore all weight 0 localities. - // - // In the future, we should look at the config in CDS response and decide - // whether locality weight matters. - newLocalitiesWithPriority := make(map[priorityType][]xdsclient.Locality) - for _, locality := range edsResp.Localities { - if locality.Weight == 0 { - continue - } - priority := newPriorityType(locality.Priority) - newLocalitiesWithPriority[priority] = append(newLocalitiesWithPriority[priority], locality) - } - - var ( - priorityLowest priorityType - priorityChanged bool - ) - - for priority, newLocalities := range newLocalitiesWithPriority { - if !priorityLowest.isSet() || priorityLowest.higherThan(priority) { - priorityLowest = priority - } - - bgwc, ok := edsImpl.priorityToLocalities[priority] - if !ok { - // Create balancer group if it's never created (this is the first - // time this priority is received). We don't start it here. It may - // be started when necessary (e.g. when higher is down, or if it's a - // new lowest priority). - ccPriorityWrapper := edsImpl.ccWrapperWithPriority(priority) - stateAggregator := weightedaggregator.New(ccPriorityWrapper, edsImpl.logger, newRandomWRR) - bgwc = &balancerGroupWithConfig{ - bg: balancergroup.New(ccPriorityWrapper, edsImpl.buildOpts, stateAggregator, edsImpl.loadReporter, edsImpl.logger), - stateAggregator: stateAggregator, - configs: make(map[xdsi.LocalityID]*localityConfig), - } - edsImpl.priorityToLocalities[priority] = bgwc - priorityChanged = true - edsImpl.logger.Infof("New priority %v added", priority) - } - edsImpl.handleEDSResponsePerPriority(bgwc, newLocalities) - } - edsImpl.priorityLowest = priorityLowest - - // Delete priorities that are removed in the latest response, and also close - // the balancer group. - for p, bgwc := range edsImpl.priorityToLocalities { - if _, ok := newLocalitiesWithPriority[p]; !ok { - delete(edsImpl.priorityToLocalities, p) - bgwc.bg.Close() - delete(edsImpl.priorityToState, p) - priorityChanged = true - edsImpl.logger.Infof("Priority %v deleted", p) - } - } - - // If priority was added/removed, it may affect the balancer group to use. - // E.g. priorityInUse was removed, or all priorities are down, and a new - // lower priority was added. - if priorityChanged { - edsImpl.handlePriorityChange() - } -} - -func (edsImpl *edsBalancerImpl) handleEDSResponsePerPriority(bgwc *balancerGroupWithConfig, newLocalities []xdsclient.Locality) { - // newLocalitiesSet contains all names of localities in the new EDS response - // for the same priority. It's used to delete localities that are removed in - // the new EDS response. - newLocalitiesSet := make(map[xdsi.LocalityID]struct{}) - var rebuildStateAndPicker bool - for _, locality := range newLocalities { - // One balancer for each locality. 
-
- lid := locality.ID
- lidJSON, err := lid.ToString()
- if err != nil {
- edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid)
- continue
- }
- newLocalitiesSet[lid] = struct{}{}
-
- newWeight := locality.Weight
- var newAddrs []resolver.Address
- for _, lbEndpoint := range locality.Endpoints {
- // Filter out all "unhealthy" endpoints (unknown and
- // healthy are both considered to be healthy:
- // https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/core/health_check.proto#envoy-api-enum-core-healthstatus).
- if lbEndpoint.HealthStatus != xdsclient.EndpointHealthStatusHealthy &&
- lbEndpoint.HealthStatus != xdsclient.EndpointHealthStatusUnknown {
- continue
- }
-
- address := resolver.Address{
- Addr: lbEndpoint.Address,
- }
- if edsImpl.subBalancerBuilder.Name() == weightedroundrobin.Name && lbEndpoint.Weight != 0 {
- ai := weightedroundrobin.AddrInfo{Weight: lbEndpoint.Weight}
- address = weightedroundrobin.SetAddrInfo(address, ai)
- // Metadata field in resolver.Address is deprecated. The
- // attributes field should be used to specify arbitrary
- // attributes about the address. We still need to populate the
- // Metadata field here to allow users of this field to migrate
- // to the new one.
- // TODO(easwars): Remove this once all users have migrated.
- // See https://github.com/grpc/grpc-go/issues/3563.
- address.Metadata = &ai
- }
- newAddrs = append(newAddrs, address)
- }
- var weightChanged, addrsChanged bool
- config, ok := bgwc.configs[lid]
- if !ok {
- // A new balancer, add it to balancer group and balancer map.
- bgwc.stateAggregator.Add(lidJSON, newWeight)
- bgwc.bg.Add(lidJSON, edsImpl.subBalancerBuilder)
- config = &localityConfig{
- weight: newWeight,
- }
- bgwc.configs[lid] = config
-
- // weightChanged is false for new locality, because there's no need
- // to update weight in bg.
- addrsChanged = true
- edsImpl.logger.Infof("New locality %v added", lid)
- } else {
- // Compare weight and addrs.
- if config.weight != newWeight {
- weightChanged = true
- }
- if !cmp.Equal(config.addrs, newAddrs) {
- addrsChanged = true
- }
- edsImpl.logger.Infof("Locality %v updated, weightChanged: %v, addrsChanged: %v", lid, weightChanged, addrsChanged)
- }
-
- if weightChanged {
- config.weight = newWeight
- bgwc.stateAggregator.UpdateWeight(lidJSON, newWeight)
- rebuildStateAndPicker = true
- }
-
- if addrsChanged {
- config.addrs = newAddrs
- bgwc.bg.UpdateClientConnState(lidJSON, balancer.ClientConnState{
- ResolverState: resolver.State{Addresses: newAddrs},
- })
- }
- }
-
- // Delete localities that are removed in the latest response.
- for lid := range bgwc.configs {
- lidJSON, err := lid.ToString()
- if err != nil {
- edsImpl.logger.Errorf("failed to marshal LocalityID: %#v, skipping this locality", lid)
- continue
- }
- if _, ok := newLocalitiesSet[lid]; !ok {
- bgwc.stateAggregator.Remove(lidJSON)
- bgwc.bg.Remove(lidJSON)
- delete(bgwc.configs, lid)
- edsImpl.logger.Infof("Locality %v deleted", lid)
- rebuildStateAndPicker = true
- }
- }
-
- if rebuildStateAndPicker {
- bgwc.stateAggregator.BuildAndUpdate()
- }
-}
-
-// handleSubConnStateChange handles the state change and updates pickers accordingly.
-func (edsImpl *edsBalancerImpl) handleSubConnStateChange(sc balancer.SubConn, s connectivity.State) {
- edsImpl.subConnMu.Lock()
- var bgwc *balancerGroupWithConfig
- if p, ok := edsImpl.subConnToPriority[sc]; ok {
- if s == connectivity.Shutdown {
- // Only delete sc from the map when state changed to Shutdown.
- delete(edsImpl.subConnToPriority, sc)
- }
- bgwc = edsImpl.priorityToLocalities[p]
- }
- edsImpl.subConnMu.Unlock()
- if bgwc == nil {
- edsImpl.logger.Infof("edsBalancerImpl: priority not found for sc state change")
- return
- }
- if bg := bgwc.bg; bg != nil {
- bg.UpdateSubConnState(sc, balancer.SubConnState{ConnectivityState: s})
- }
-}
-
-// updateServiceRequestsConfig handles changes to the circuit breaking configuration.
-func (edsImpl *edsBalancerImpl) updateServiceRequestsConfig(serviceName string, max *uint32) {
- if !env.CircuitBreakingSupport {
- return
- }
- edsImpl.pickerMu.Lock()
- var updatePicker bool
- if edsImpl.serviceRequestsCounter == nil || edsImpl.serviceRequestsCounter.ServiceName != serviceName {
- edsImpl.serviceRequestsCounter = client.GetServiceRequestsCounter(serviceName)
- updatePicker = true
- }
-
- var newMax uint32 = defaultServiceRequestCountMax
- if max != nil {
- newMax = *max
- }
- if edsImpl.serviceRequestCountMax != newMax {
- edsImpl.serviceRequestCountMax = newMax
- updatePicker = true
- }
- if updatePicker && edsImpl.innerState.Picker != nil {
- // Update picker with old inner picker, new counter and counterMax.
- edsImpl.cc.UpdateState(balancer.State{
- ConnectivityState: edsImpl.innerState.ConnectivityState,
- Picker: newDropPicker(edsImpl.innerState.Picker, edsImpl.drops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)},
- )
- }
- edsImpl.pickerMu.Unlock()
-}
-
-func (edsImpl *edsBalancerImpl) updateClusterName(name string) {
- edsImpl.clusterNameMu.Lock()
- defer edsImpl.clusterNameMu.Unlock()
- edsImpl.clusterName = name
-}
-
-func (edsImpl *edsBalancerImpl) getClusterName() string {
- edsImpl.clusterNameMu.Lock()
- defer edsImpl.clusterNameMu.Unlock()
- return edsImpl.clusterName
-}
-
-// updateState first handles priority, and then wraps picker in a drop picker
-// before forwarding the update.
-func (edsImpl *edsBalancerImpl) updateState(priority priorityType, s balancer.State) {
- _, ok := edsImpl.priorityToLocalities[priority]
- if !ok {
- edsImpl.logger.Infof("eds: received picker update from unknown priority")
- return
- }
-
- if edsImpl.handlePriorityWithNewState(priority, s) {
- edsImpl.pickerMu.Lock()
- defer edsImpl.pickerMu.Unlock()
- edsImpl.innerState = s
- // Don't reset drops when it's a state change.
- edsImpl.cc.UpdateState(balancer.State{ConnectivityState: s.ConnectivityState, Picker: newDropPicker(s.Picker, edsImpl.drops, edsImpl.loadReporter, edsImpl.serviceRequestsCounter, edsImpl.serviceRequestCountMax)})
- }
-}
-
-func (edsImpl *edsBalancerImpl) ccWrapperWithPriority(priority priorityType) *edsBalancerWrapperCC {
- return &edsBalancerWrapperCC{
- ClientConn: edsImpl.cc,
- priority: priority,
- parent: edsImpl,
- }
-}
-
-// edsBalancerWrapperCC implements the balancer.ClientConn API and gets passed to
-// each balancer group. It contains the locality priority.
-type edsBalancerWrapperCC struct { - balancer.ClientConn - priority priorityType - parent *edsBalancerImpl -} - -func (ebwcc *edsBalancerWrapperCC) NewSubConn(addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { - clusterName := ebwcc.parent.getClusterName() - newAddrs := make([]resolver.Address, len(addrs)) - for i, addr := range addrs { - newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) - } - return ebwcc.parent.newSubConn(ebwcc.priority, newAddrs, opts) -} - -func (ebwcc *edsBalancerWrapperCC) UpdateAddresses(sc balancer.SubConn, addrs []resolver.Address) { - clusterName := ebwcc.parent.getClusterName() - newAddrs := make([]resolver.Address, len(addrs)) - for i, addr := range addrs { - newAddrs[i] = internal.SetXDSHandshakeClusterName(addr, clusterName) - } - ebwcc.ClientConn.UpdateAddresses(sc, newAddrs) -} - -func (ebwcc *edsBalancerWrapperCC) UpdateState(state balancer.State) { - ebwcc.parent.enqueueChildBalancerStateUpdate(ebwcc.priority, state) -} - -func (edsImpl *edsBalancerImpl) newSubConn(priority priorityType, addrs []resolver.Address, opts balancer.NewSubConnOptions) (balancer.SubConn, error) { - sc, err := edsImpl.cc.NewSubConn(addrs, opts) - if err != nil { - return nil, err - } - edsImpl.subConnMu.Lock() - edsImpl.subConnToPriority[sc] = priority - edsImpl.subConnMu.Unlock() - return sc, nil -} - -// close closes the balancer. -func (edsImpl *edsBalancerImpl) close() { - for _, bgwc := range edsImpl.priorityToLocalities { - if bg := bgwc.bg; bg != nil { - bgwc.stateAggregator.Stop() - bg.Close() - } - } -} - -type dropPicker struct { - drops []*dropper - p balancer.Picker - loadStore load.PerClusterReporter - counter *client.ServiceRequestsCounter - countMax uint32 -} - -func newDropPicker(p balancer.Picker, drops []*dropper, loadStore load.PerClusterReporter, counter *client.ServiceRequestsCounter, countMax uint32) *dropPicker { - return &dropPicker{ - drops: drops, - p: p, - loadStore: loadStore, - counter: counter, - countMax: countMax, - } -} - -func (d *dropPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - var ( - drop bool - category string - ) - for _, dp := range d.drops { - if dp.drop() { - drop = true - category = dp.c.Category - break - } - } - if drop { - if d.loadStore != nil { - d.loadStore.CallDropped(category) - } - return balancer.PickResult{}, status.Errorf(codes.Unavailable, "RPC is dropped") - } - if d.counter != nil { - if err := d.counter.StartRequest(d.countMax); err != nil { - // Drops by circuit breaking are reported with empty category. They - // will be reported only in total drops, but not in per category. - if d.loadStore != nil { - d.loadStore.CallDropped("") - } - return balancer.PickResult{}, status.Errorf(codes.Unavailable, err.Error()) - } - pr, err := d.p.Pick(info) - if err != nil { - d.counter.EndRequest() - return pr, err - } - oldDone := pr.Done - pr.Done = func(doneInfo balancer.DoneInfo) { - d.counter.EndRequest() - if oldDone != nil { - oldDone(doneInfo) - } - } - return pr, err - } - // TODO: (eds) don't drop unless the inner picker is READY. Similar to - // https://github.com/grpc/grpc-go/issues/2622. - return d.p.Pick(info) -} diff --git a/xds/internal/balancer/edsbalancer/eds_impl_priority.go b/xds/internal/balancer/edsbalancer/eds_impl_priority.go deleted file mode 100644 index 53ac6ef5e873..000000000000 --- a/xds/internal/balancer/edsbalancer/eds_impl_priority.go +++ /dev/null @@ -1,358 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. 
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package edsbalancer
-
-import (
- "errors"
- "fmt"
- "time"
-
- "google.golang.org/grpc/balancer"
- "google.golang.org/grpc/balancer/base"
- "google.golang.org/grpc/connectivity"
-)
-
-var errAllPrioritiesRemoved = errors.New("eds: no locality is provided, all priorities are removed")
-
-// handlePriorityChange handles priority after EDS adds/removes a
-// priority.
-//
-// - If all priorities were deleted, unset priorityInUse, and set parent
-// ClientConn to TransientFailure
-// - If priorityInUse wasn't set, this is either the first EDS resp, or the
-// previous EDS resp deleted everything. Set priorityInUse to 0, and start 0.
-// - If priorityInUse was deleted, send the picker from the new lowest priority
-// to parent ClientConn, and set priorityInUse to the new lowest.
-// - If priorityInUse has a non-Ready state, and also there's a priority lower
-// than priorityInUse (which means a lower priority was added), set the next
-// priority as new priorityInUse, and start the bg.
-func (edsImpl *edsBalancerImpl) handlePriorityChange() {
- edsImpl.priorityMu.Lock()
- defer edsImpl.priorityMu.Unlock()
-
- // Everything was removed by EDS.
- if !edsImpl.priorityLowest.isSet() {
- edsImpl.priorityInUse = newPriorityTypeUnset()
- // Stop the init timer. This can happen if the only priority is removed
- // shortly after it's added.
- if timer := edsImpl.priorityInitTimer; timer != nil {
- timer.Stop()
- edsImpl.priorityInitTimer = nil
- }
- edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.TransientFailure, Picker: base.NewErrPicker(errAllPrioritiesRemoved)})
- return
- }
-
- // priorityInUse wasn't set, use 0.
- if !edsImpl.priorityInUse.isSet() {
- edsImpl.logger.Infof("Switching priority from unset to %v", 0)
- edsImpl.startPriority(newPriorityType(0))
- return
- }
-
- // priorityInUse was deleted, use the new lowest.
- if _, ok := edsImpl.priorityToLocalities[edsImpl.priorityInUse]; !ok {
- oldP := edsImpl.priorityInUse
- edsImpl.priorityInUse = edsImpl.priorityLowest
- edsImpl.logger.Infof("Switching priority from %v to %v, because former was deleted", oldP, edsImpl.priorityInUse)
- if s, ok := edsImpl.priorityToState[edsImpl.priorityLowest]; ok {
- edsImpl.cc.UpdateState(*s)
- } else {
- // If state for priorityLowest is not found, this means priorityLowest was
- // started, but never sent any update. The init timer fired and
- // triggered the next priority. The old priorityInUse (that was just
- // deleted from EDS) was picked later.
- //
- // We don't have an old state to send to parent, but we also don't
- // want parent to keep using picker from old priorityInUse. Send an
- // update to block picks until a new picker is ready.
- edsImpl.cc.UpdateState(balancer.State{ConnectivityState: connectivity.Connecting, Picker: base.NewErrPicker(balancer.ErrNoSubConnAvailable)})
- }
- return
- }
-
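// For reference, base.NewErrPicker (used above for both the TransientFailure
// and Connecting updates) returns a picker that fails every pick with the
// given error. A minimal illustration of its behavior:
//
//	p := base.NewErrPicker(balancer.ErrNoSubConnAvailable)
//	_, err := p.Pick(balancer.PickInfo{}) // err == balancer.ErrNoSubConnAvailable
//
// Returning balancer.ErrNoSubConnAvailable makes the channel queue picks and
// wait for a new picker, rather than failing the RPC outright.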
- // priorityInUse is not ready, look for next priority, and use if found.
- if s, ok := edsImpl.priorityToState[edsImpl.priorityInUse]; ok && s.ConnectivityState != connectivity.Ready {
- pNext := edsImpl.priorityInUse.nextLower()
- if _, ok := edsImpl.priorityToLocalities[pNext]; ok {
- edsImpl.logger.Infof("Switching priority from %v to %v, because latter was added, and former wasn't Ready", edsImpl.priorityInUse, pNext)
- edsImpl.startPriority(pNext)
- }
- }
-}
-
-// startPriority sets priorityInUse to p, and starts the balancer group for p.
-// It also starts a timer to fall to the next priority after timeout.
-//
-// Caller must hold priorityMu, priority must exist, and edsImpl.priorityInUse
-// must be non-nil.
-func (edsImpl *edsBalancerImpl) startPriority(priority priorityType) {
- edsImpl.priorityInUse = priority
- p := edsImpl.priorityToLocalities[priority]
- // NOTE: this will eventually send addresses to sub-balancers. If the
- // sub-balancer tries to update picker, it will result in a deadlock on
- // priorityMu if the update is handled synchronously. The deadlock is
- // currently avoided by handling balancer update in a goroutine (the run
- // goroutine in the parent eds balancer). When priority balancer is split
- // into its own, this asynchronous state handling needs to be copied.
- p.stateAggregator.Start()
- p.bg.Start()
- // startPriority can be called when
- // 1. first EDS resp, start p0
- // 2. a high priority goes Failure, start next
- // 3. a high priority init timeout, start next
- //
- // In all the cases, the existing init timer is either stopped or already
- // expired. There's no need to close the old timer.
- edsImpl.priorityInitTimer = time.AfterFunc(defaultPriorityInitTimeout, func() {
- edsImpl.priorityMu.Lock()
- defer edsImpl.priorityMu.Unlock()
- if !edsImpl.priorityInUse.isSet() || !edsImpl.priorityInUse.equal(priority) {
- return
- }
- edsImpl.priorityInitTimer = nil
- pNext := priority.nextLower()
- if _, ok := edsImpl.priorityToLocalities[pNext]; ok {
- edsImpl.startPriority(pNext)
- }
- })
-}
-
-// handlePriorityWithNewState starts/closes priorities based on the connectivity
-// state. It returns whether the state should be forwarded to parent ClientConn.
-func (edsImpl *edsBalancerImpl) handlePriorityWithNewState(priority priorityType, s balancer.State) bool {
- edsImpl.priorityMu.Lock()
- defer edsImpl.priorityMu.Unlock()
-
- if !edsImpl.priorityInUse.isSet() {
- edsImpl.logger.Infof("eds: received picker update when no priority is in use (EDS returned an empty list)")
- return false
- }
-
- if edsImpl.priorityInUse.higherThan(priority) {
- // Lower priorities should all be closed, this is an unexpected update.
- edsImpl.logger.Infof("eds: received picker update from priority lower than priorityInUse")
- return false
- }
-
- bState, ok := edsImpl.priorityToState[priority]
- if !ok {
- bState = &balancer.State{}
- edsImpl.priorityToState[priority] = bState
- }
- oldState := bState.ConnectivityState
- *bState = s
-
- switch s.ConnectivityState {
- case connectivity.Ready:
- return edsImpl.handlePriorityWithNewStateReady(priority)
- case connectivity.TransientFailure:
- return edsImpl.handlePriorityWithNewStateTransientFailure(priority)
- case connectivity.Connecting:
- return edsImpl.handlePriorityWithNewStateConnecting(priority, oldState)
- default:
- // New state is Idle, should never happen. Don't forward.
- return false
- }
-}
-
-// handlePriorityWithNewStateReady handles state Ready and decides whether to
-// forward update or not.
-//
-// An update with state Ready:
-// - If it's from higher priority:
-//   - Forward the update
-//   - Set the priority as priorityInUse
-//   - Close all priorities lower than this one
-// - If it's from priorityInUse:
-//   - Forward and do nothing else
-//
-// Caller must make sure priorityInUse is not higher than priority.
-//
-// Caller must hold priorityMu.
-func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateReady(priority priorityType) bool {
- // If one priority higher or equal to priorityInUse goes Ready, stop the
- // init timer. If the update is from a priority higher than priorityInUse,
- // priorityInUse will be closed, and the init timer will become useless.
- if timer := edsImpl.priorityInitTimer; timer != nil {
- timer.Stop()
- edsImpl.priorityInitTimer = nil
- }
-
- if edsImpl.priorityInUse.lowerThan(priority) {
- edsImpl.logger.Infof("Switching priority from %v to %v, because latter became Ready", edsImpl.priorityInUse, priority)
- edsImpl.priorityInUse = priority
- for i := priority.nextLower(); !i.lowerThan(edsImpl.priorityLowest); i = i.nextLower() {
- bgwc := edsImpl.priorityToLocalities[i]
- bgwc.stateAggregator.Stop()
- bgwc.bg.Close()
- }
- return true
- }
- return true
-}
-
-// handlePriorityWithNewStateTransientFailure handles state TransientFailure and
-// decides whether to forward update or not.
-//
-// An update with state Failure:
-// - If it's from a higher priority:
-//   - Do not forward, and do nothing
-// - If it's from priorityInUse:
-//   - If there's no lower:
-//     - Forward and do nothing else
-//   - If there's a lower priority:
-//     - Forward
-//     - Set lower as priorityInUse
-//     - Start lower
-//
-// Caller must make sure priorityInUse is not higher than priority.
-//
-// Caller must hold priorityMu.
-func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateTransientFailure(priority priorityType) bool {
- if edsImpl.priorityInUse.lowerThan(priority) {
- return false
- }
- // priorityInUse sends a failure. Stop its init timer.
- if timer := edsImpl.priorityInitTimer; timer != nil {
- timer.Stop()
- edsImpl.priorityInitTimer = nil
- }
- pNext := priority.nextLower()
- if _, okNext := edsImpl.priorityToLocalities[pNext]; !okNext {
- return true
- }
- edsImpl.logger.Infof("Switching priority from %v to %v, because former became TransientFailure", priority, pNext)
- edsImpl.startPriority(pNext)
- return true
-}
-
-// handlePriorityWithNewStateConnecting handles state Connecting and decides
-// whether to forward update or not.
-//
-// An update with state Connecting:
-// - If it's from a higher priority
-//   - Do nothing
-// - If it's from priorityInUse, the behavior depends on previous state.
-//
-// When new state is Connecting, the behavior depends on previous state. If the
-// previous state was Ready, this is a transition out from Ready to Connecting.
-// Assuming there are multiple backends in the same priority, this means we are
-// in a bad situation and we should failover to the next priority (Side note:
-// the current connectivity state aggregating algorithm (e.g. round-robin) is
-// not handling this right, because if many backends all go from Ready to
-// Connecting, the overall situation is more like TransientFailure, not
-// Connecting).
-//
-// If the previous state was Idle, we don't do anything special with failure,
-// and simply forward the update. The init timer should be in progress, and will
-// handle failover if it times out. If the previous state was TransientFailure,
-// we do not forward, because the lower priority is in use.
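// Taken together, the three state handlers implement this decision table
// (the caller guarantees updates never come from a priority lower than
// priorityInUse):
//
//	new state         from a higher priority        from priorityInUse
//	Ready             forward; make it              forward
//	                  priorityInUse and close
//	                  all lower priorities
//	TransientFailure  not forwarded                 forward; start the next
//	                                                lower priority if any
//	Connecting        not forwarded                 previous Ready: forward and
//	                                                start next lower; previous
//	                                                Idle: forward; previous
//	                                                TransientFailure: not forwarded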
-// -// Caller must make sure priorityInUse is not higher than priority. -// -// Caller must hold priorityMu. -func (edsImpl *edsBalancerImpl) handlePriorityWithNewStateConnecting(priority priorityType, oldState connectivity.State) bool { - if edsImpl.priorityInUse.lowerThan(priority) { - return false - } - - switch oldState { - case connectivity.Ready: - pNext := priority.nextLower() - if _, okNext := edsImpl.priorityToLocalities[pNext]; !okNext { - return true - } - edsImpl.logger.Infof("Switching priority from %v to %v, because former became Connecting from Ready", priority, pNext) - edsImpl.startPriority(pNext) - return true - case connectivity.Idle: - return true - case connectivity.TransientFailure: - return false - default: - // Old state is Connecting or Shutdown. Don't forward. - return false - } -} - -// priorityType represents the priority from EDS response. -// -// 0 is the highest priority. The bigger the number, the lower the priority. -type priorityType struct { - set bool - p uint32 -} - -func newPriorityType(p uint32) priorityType { - return priorityType{ - set: true, - p: p, - } -} - -func newPriorityTypeUnset() priorityType { - return priorityType{} -} - -func (p priorityType) isSet() bool { - return p.set -} - -func (p priorityType) equal(p2 priorityType) bool { - if !p.isSet() && !p2.isSet() { - return true - } - if !p.isSet() || !p2.isSet() { - return false - } - return p == p2 -} - -func (p priorityType) higherThan(p2 priorityType) bool { - if !p.isSet() || !p2.isSet() { - // TODO(menghanl): return an appropriate value instead of panic. - panic("priority unset") - } - return p.p < p2.p -} - -func (p priorityType) lowerThan(p2 priorityType) bool { - if !p.isSet() || !p2.isSet() { - // TODO(menghanl): return an appropriate value instead of panic. - panic("priority unset") - } - return p.p > p2.p -} - -func (p priorityType) nextLower() priorityType { - if !p.isSet() { - panic("priority unset") - } - return priorityType{ - set: true, - p: p.p + 1, - } -} - -func (p priorityType) String() string { - if !p.set { - return "Nil" - } - return fmt.Sprint(p.p) -} diff --git a/xds/internal/balancer/edsbalancer/eds_impl_test.go b/xds/internal/balancer/edsbalancer/eds_impl_test.go deleted file mode 100644 index 3cfc620a2400..000000000000 --- a/xds/internal/balancer/edsbalancer/eds_impl_test.go +++ /dev/null @@ -1,1009 +0,0 @@ -// +build go1.12 - -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package edsbalancer - -import ( - "fmt" - "reflect" - "sort" - "testing" - "time" - - corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" - "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" - "google.golang.org/grpc/internal" - "google.golang.org/grpc/xds/internal/balancer/loadstore" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/roundrobin" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal/balancer/stub" - "google.golang.org/grpc/internal/xds/env" - xdsinternal "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/balancer/balancergroup" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" - "google.golang.org/grpc/xds/internal/testutils" -) - -var ( - testClusterNames = []string{"test-cluster-1", "test-cluster-2"} - testSubZones = []string{"I", "II", "III", "IV"} - testEndpointAddrs []string -) - -const testBackendAddrsCount = 12 - -func init() { - for i := 0; i < testBackendAddrsCount; i++ { - testEndpointAddrs = append(testEndpointAddrs, fmt.Sprintf("%d.%d.%d.%d:%d", i, i, i, i, i)) - } - balancergroup.DefaultSubBalancerCloseTimeout = time.Millisecond -} - -// One locality -// - add backend -// - remove backend -// - replace backend -// - change drop rate -func (s) TestEDS_OneLocality(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // One locality with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Pick with only the first backend. - p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - } - - // The same locality, add one more backend. - clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test roundrobin with two subconns. - p2 := <-cc.NewPickerCh - want := []balancer.SubConn{sc1, sc2} - if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil { - t.Fatalf("want %v, got %v", want, err) - } - - // The same locality, delete first backend. - clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab3.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build())) - - scToRemove := <-cc.RemoveSubConnCh - if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove) - } - edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown) - - // Test pick with only the second subconn. 
- p3 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- gotSCSt, _ := p3.Pick(balancer.PickInfo{})
- if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2)
- }
- }
-
- // The same locality, replace backend.
- clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab4.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab4.Build()))
-
- sc3 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc3, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc3, connectivity.Ready)
- scToRemove = <-cc.RemoveSubConnCh
- if !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove)
- }
- edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown)
-
- // Test pick with only the third subconn.
- p4 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- gotSCSt, _ := p4.Pick(balancer.PickInfo{})
- if !cmp.Equal(gotSCSt.SubConn, sc3, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc3)
- }
- }
-
- // The same locality, different drop rate, dropping 50%.
- clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{"test-drop": 50})
- clab5.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab5.Build()))
-
- // Picks with drops.
- p5 := <-cc.NewPickerCh
- for i := 0; i < 100; i++ {
- _, err := p5.Pick(balancer.PickInfo{})
- // TODO: the dropping algorithm needs a design. When the dropping algorithm
- // is fixed, this test also needs fix.
- if i < 50 && err == nil {
- t.Errorf("The first 50%% picks should be drops, got error <nil>")
- } else if i > 50 && err != nil {
- t.Errorf("The second 50%% picks should be non-drops, got error %v", err)
- }
- }
-
- // The same locality, remove drops.
- clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab6.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[2:3], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab6.Build()))
-
- // Pick without drops.
- p6 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- gotSCSt, _ := p6.Pick(balancer.PickInfo{})
- if !cmp.Equal(gotSCSt.SubConn, sc3, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc3)
- }
- }
-}
-
-// 2 localities
-// - start with 2 localities
-// - add locality
-// - remove locality
-// - address change for the locality
-// - update locality weight
-func (s) TestEDS_TwoLocalities(t *testing.T) {
- cc := testutils.NewTestClientConn(t)
- edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil)
- edsb.enqueueChildBalancerStateUpdate = edsb.updateState
-
- // Two localities, each with one backend.
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
- sc1 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc1, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc1, connectivity.Ready)
-
- // Add the second locality later to make sure sc2 belongs to the second
- // locality. Otherwise the test is flaky because a map is used in EDS to
- // keep localities.
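// The flakiness noted above comes from Go's randomized map iteration order:
// if both localities arrive in one update, which locality's SubConn is
// created first is not deterministic. A self-contained illustration:
//
//	m := map[string]int{"I": 1, "II": 2}
//	for k := range m {
//		fmt.Println(k) // "I" and "II" can come out in either order across runs.
//	}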
- clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
- sc2 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc2, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc2, connectivity.Ready)
-
- // Test roundrobin with two subconns.
- p1 := <-cc.NewPickerCh
- want := []balancer.SubConn{sc1, sc2}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-
- // Add another locality, with one backend.
- clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
- clab2.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil)
- clab2.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build()))
-
- sc3 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc3, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc3, connectivity.Ready)
-
- // Test roundrobin with three subconns.
- p2 := <-cc.NewPickerCh
- want = []balancer.SubConn{sc1, sc2, sc3}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p2)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-
- // Remove first locality.
- clab3 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab3.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil)
- clab3.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:3], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab3.Build()))
-
- scToRemove := <-cc.RemoveSubConnCh
- if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("RemoveSubConn, want %v, got %v", sc1, scToRemove)
- }
- edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown)
-
- // Test pick with two subconns (without the first one).
- p3 := <-cc.NewPickerCh
- want = []balancer.SubConn{sc2, sc3}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-
- // Add a backend to the last locality.
- clab4 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab4.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil)
- clab4.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab4.Build()))
-
- sc4 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc4, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc4, connectivity.Ready)
-
- // Test picks with the new backend added to the last locality.
- p4 := <-cc.NewPickerCh
- // Locality-1 will be picked twice, and locality-2 will be picked twice.
- // Locality-1 contains only sc2, locality-2 contains sc3 and sc4. So expect
- // two sc2's and sc3, sc4.
- want = []balancer.SubConn{sc2, sc2, sc3, sc4}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p4)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-
- // Change weight of the locality[1].
- clab5 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab5.AddLocality(testSubZones[1], 2, 0, testEndpointAddrs[1:2], nil)
- clab5.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab5.Build()))
-
- // Test picks with two subconns that have different locality weights.
- p5 := <-cc.NewPickerCh
- // Locality-1 will be picked four times, and locality-2 will be picked twice
- // (weight 2 and 1). Locality-1 contains only sc2, locality-2 contains sc3 and
- // sc4. So expect four sc2's and sc3, sc4.
- want = []balancer.SubConn{sc2, sc2, sc2, sc2, sc3, sc4}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p5)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-
- // Change weight of the locality[1] to 0, it should never be picked.
- clab6 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab6.AddLocality(testSubZones[1], 0, 0, testEndpointAddrs[1:2], nil)
- clab6.AddLocality(testSubZones[2], 1, 0, testEndpointAddrs[2:4], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab6.Build()))
-
- // Changing weight of locality[1] to 0 caused it to be removed. Its subconn
- // should also be removed.
- //
- // NOTE: this is because we handle a locality with weight 0 the same as if
- // the locality doesn't exist. If this changes in the future, this
- // removeSubConn behavior will also change.
- scToRemove2 := <-cc.RemoveSubConnCh
- if !cmp.Equal(scToRemove2, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("RemoveSubConn, want %v, got %v", sc2, scToRemove2)
- }
-
- // Test picks after the weight change.
- p6 := <-cc.NewPickerCh
- // Locality-1 will not be picked, and locality-2 will be picked.
- // Locality-2 contains sc3 and sc4. So expect sc3, sc4.
- want = []balancer.SubConn{sc3, sc4}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p6)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-}
-
-// The EDS balancer gets EDS resp with unhealthy endpoints. Test that only
-// healthy ones are used.
-func (s) TestEDS_EndpointsHealth(t *testing.T) {
- cc := testutils.NewTestClientConn(t)
- edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil)
- edsb.enqueueChildBalancerStateUpdate = edsb.updateState
-
- // Two localities, each with six backends: one Healthy, one Unhealthy, one
- // Unknown, one Draining, one Timeout, and one Degraded.
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:6], &testutils.AddLocalityOptions{
- Health: []corepb.HealthStatus{
- corepb.HealthStatus_HEALTHY,
- corepb.HealthStatus_UNHEALTHY,
- corepb.HealthStatus_UNKNOWN,
- corepb.HealthStatus_DRAINING,
- corepb.HealthStatus_TIMEOUT,
- corepb.HealthStatus_DEGRADED,
- },
- })
- clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[6:12], &testutils.AddLocalityOptions{
- Health: []corepb.HealthStatus{
- corepb.HealthStatus_HEALTHY,
- corepb.HealthStatus_UNHEALTHY,
- corepb.HealthStatus_UNKNOWN,
- corepb.HealthStatus_DRAINING,
- corepb.HealthStatus_TIMEOUT,
- corepb.HealthStatus_DEGRADED,
- },
- })
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
-
- var (
- readySCs []balancer.SubConn
- newSubConnAddrStrs []string
- )
- for i := 0; i < 4; i++ {
- addr := <-cc.NewSubConnAddrsCh
- newSubConnAddrStrs = append(newSubConnAddrStrs, addr[0].Addr)
- sc := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc, connectivity.Ready)
- readySCs = append(readySCs, sc)
- }
-
- wantNewSubConnAddrStrs := []string{
- testEndpointAddrs[0],
- testEndpointAddrs[2],
- testEndpointAddrs[6],
- testEndpointAddrs[8],
- }
- sortStrTrans := cmp.Transformer("Sort", func(in []string) []string {
- out := append([]string(nil), in...) // Copy input to avoid mutating it.
- sort.Strings(out)
- return out
- })
- if !cmp.Equal(newSubConnAddrStrs, wantNewSubConnAddrStrs, sortStrTrans) {
- t.Fatalf("want newSubConn with address %v, got %v", wantNewSubConnAddrStrs, newSubConnAddrStrs)
- }
-
- // There should be exactly 4 new SubConns. Check to make sure no more
- // subconns are being created.
- select {
- case <-cc.NewSubConnCh:
- t.Fatalf("Got unexpected new subconn")
- case <-time.After(time.Microsecond * 100):
- }
-
- // Test roundrobin with the subconns.
- p1 := <-cc.NewPickerCh
- want := readySCs
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-}
-
-func (s) TestClose(t *testing.T) {
- edsb := newEDSBalancerImpl(nil, balancer.BuildOptions{}, nil, nil, nil)
- // This is what could happen when switching between fallback and eds. This
- // makes sure it doesn't panic.
- edsb.close()
-}
-
-// TestEDS_EmptyUpdate covers the cases when eds impl receives an empty update.
-//
-// It should send an error picker with transient failure to the parent.
-func (s) TestEDS_EmptyUpdate(t *testing.T) {
- cc := testutils.NewTestClientConn(t)
- edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil)
- edsb.enqueueChildBalancerStateUpdate = edsb.updateState
-
- // The first update is an empty update.
- edsb.handleEDSResponse(xdsclient.EndpointsUpdate{})
- // Pick should fail with transient failure and the all-priorities-removed error.
- perr0 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- _, err := perr0.Pick(balancer.PickInfo{})
- if !reflect.DeepEqual(err, errAllPrioritiesRemoved) {
- t.Fatalf("picker.Pick, got error %v, want error %v", err, errAllPrioritiesRemoved)
- }
- }
-
- // One locality with one backend.
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
-
- sc1 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc1, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc1, connectivity.Ready)
-
- // Pick with only the first backend.
- p1 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- gotSCSt, _ := p1.Pick(balancer.PickInfo{})
- if !reflect.DeepEqual(gotSCSt.SubConn, sc1) {
- t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1)
- }
- }
-
- edsb.handleEDSResponse(xdsclient.EndpointsUpdate{})
- // Pick should fail with transient failure and the all-priorities-removed error.
- perr1 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- _, err := perr1.Pick(balancer.PickInfo{})
- if !reflect.DeepEqual(err, errAllPrioritiesRemoved) {
- t.Fatalf("picker.Pick, got error %v, want error %v", err, errAllPrioritiesRemoved)
- }
- }
-
- // Handle another update with priorities and localities.
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
-
- sc2 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc2, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc2, connectivity.Ready)
-
- // Pick with only the first backend.
- p2 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- gotSCSt, _ := p2.Pick(balancer.PickInfo{})
- if !reflect.DeepEqual(gotSCSt.SubConn, sc2) {
- t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2)
- }
- }
-}
-
-// Create XDS balancer, and update sub-balancer before handling eds responses.
-// Then switch between round-robin and a test stub-balancer after handling first
-// eds response.
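// The test below drives child-policy switching with a stub balancer. The stub
// package (google.golang.org/grpc/internal/balancer/stub, used in the deleted
// code) builds a balancer from caller-provided callbacks. A minimal usage
// sketch with a hypothetical name, following the same shape as the test:
//
//	stub.Register("my-test-balancer", stub.BalancerFuncs{
//		UpdateClientConnState: func(bd *stub.BalancerData, s balancer.ClientConnState) error {
//			// Create one SubConn covering all resolved addresses.
//			bd.ClientConn.NewSubConn(s.ResolverState.Addresses, balancer.NewSubConnOptions{})
//			return nil
//		},
//	})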
-func (s) TestEDS_UpdateSubBalancerName(t *testing.T) { - const balancerName = "stubBalancer-TestEDS_UpdateSubBalancerName" - - cc := testutils.NewTestClientConn(t) - stub.Register(balancerName, stub.BalancerFuncs{ - UpdateClientConnState: func(bd *stub.BalancerData, s balancer.ClientConnState) error { - if len(s.ResolverState.Addresses) == 0 { - return nil - } - bd.ClientConn.NewSubConn(s.ResolverState.Addresses, balancer.NewSubConnOptions{}) - return nil - }, - UpdateSubConnState: func(bd *stub.BalancerData, sc balancer.SubConn, state balancer.SubConnState) { - bd.ClientConn.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &testutils.TestConstPicker{Err: testutils.ErrTestConstPicker}, - }) - }, - }) - - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - t.Logf("update sub-balancer to stub-balancer") - edsb.handleChildPolicy(balancerName, nil) - - // Two localities, each with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - for i := 0; i < 2; i++ { - sc := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc, connectivity.Ready) - } - - p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := p0.Pick(balancer.PickInfo{}) - if err != testutils.ErrTestConstPicker { - t.Fatalf("picker.Pick, got err %+v, want err %+v", err, testutils.ErrTestConstPicker) - } - } - - t.Logf("update sub-balancer to round-robin") - edsb.handleChildPolicy(roundrobin.Name, nil) - - for i := 0; i < 2; i++ { - <-cc.RemoveSubConnCh - } - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test roundrobin with two subconns. 
- p1 := <-cc.NewPickerCh
- want := []balancer.SubConn{sc1, sc2}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p1)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-
- t.Logf("update sub-balancer to stub-balancer")
- edsb.handleChildPolicy(balancerName, nil)
-
- for i := 0; i < 2; i++ {
- scToRemove := <-cc.RemoveSubConnCh
- if !cmp.Equal(scToRemove, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) &&
- !cmp.Equal(scToRemove, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) {
- t.Fatalf("RemoveSubConn, want (%v or %v), got %v", sc1, sc2, scToRemove)
- }
- edsb.handleSubConnStateChange(scToRemove, connectivity.Shutdown)
- }
-
- for i := 0; i < 2; i++ {
- sc := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc, connectivity.Ready)
- }
-
- p2 := <-cc.NewPickerCh
- for i := 0; i < 5; i++ {
- _, err := p2.Pick(balancer.PickInfo{})
- if err != testutils.ErrTestConstPicker {
- t.Fatalf("picker.Pick, got err %q, want err %q", err, testutils.ErrTestConstPicker)
- }
- }
-
- t.Logf("update sub-balancer to round-robin")
- edsb.handleChildPolicy(roundrobin.Name, nil)
-
- for i := 0; i < 2; i++ {
- <-cc.RemoveSubConnCh
- }
-
- sc3 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc3, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc3, connectivity.Ready)
- sc4 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc4, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc4, connectivity.Ready)
-
- p3 := <-cc.NewPickerCh
- want = []balancer.SubConn{sc3, sc4}
- if err := testutils.IsRoundRobin(want, subConnFromPicker(p3)); err != nil {
- t.Fatalf("want %v, got %v", want, err)
- }
-}
-
-func (s) TestEDS_CircuitBreaking(t *testing.T) {
- origCircuitBreakingSupport := env.CircuitBreakingSupport
- env.CircuitBreakingSupport = true
- defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }()
-
- cc := testutils.NewTestClientConn(t)
- edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil)
- edsb.enqueueChildBalancerStateUpdate = edsb.updateState
- var maxRequests uint32 = 50
- edsb.updateServiceRequestsConfig("test", &maxRequests)
-
- // One locality with one backend.
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil)
- clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil)
- edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build()))
- sc1 := <-cc.NewSubConnCh
- edsb.handleSubConnStateChange(sc1, connectivity.Connecting)
- edsb.handleSubConnStateChange(sc1, connectivity.Ready)
-
- // Picks with drops.
- dones := []func(){}
- p := <-cc.NewPickerCh
- for i := 0; i < 100; i++ {
- pr, err := p.Pick(balancer.PickInfo{})
- if i < 50 && err != nil {
- t.Errorf("The first 50%% picks should be non-drops, got error %v", err)
- } else if i > 50 && err == nil {
- t.Errorf("The second 50%% picks should be drops, got error <nil>")
- }
- dones = append(dones, func() {
- if pr.Done != nil {
- pr.Done(balancer.DoneInfo{})
- }
- })
- }
-
- for _, done := range dones {
- done()
- }
- dones = []func(){}
-
- // Pick without drops.
- for i := 0; i < 50; i++ {
- pr, err := p.Pick(balancer.PickInfo{})
- if err != nil {
- t.Errorf("The third 50%% picks should be non-drops, got error %v", err)
- }
- dones = append(dones, func() {
- if pr.Done != nil {
- pr.Done(balancer.DoneInfo{})
- }
- })
- }
-
- // Without this, future tests with the same service name will fail.
- for _, done := range dones {
- done()
- }
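// The picks above are gated by the per-service request counter configured via
// updateServiceRequestsConfig: every admitted pick calls StartRequest against
// the configured max, and every finished RPC calls EndRequest. A minimal
// sketch of that contract (hedged; the real call sites are in dropPicker.Pick
// earlier in this diff):
//
//	counter := client.GetServiceRequestsCounter("test")
//	if err := counter.StartRequest(maxRequests); err != nil {
//		// Over the limit: the pick fails as a circuit-breaking drop.
//		return balancer.PickResult{}, status.Errorf(codes.Unavailable, err.Error())
//	}
//	// On RPC completion (the pr.Done callbacks above):
//	counter.EndRequest() // Must be balanced with StartRequest or the count leaks.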
-
- // Send another update, with only a circuit breaking update (and no picker
- // update afterwards). Make sure the new picker uses the new configs.
- var maxRequests2 uint32 = 10
- edsb.updateServiceRequestsConfig("test", &maxRequests2)
-
- // Picks with drops.
- dones = []func(){}
- p2 := <-cc.NewPickerCh
- for i := 0; i < 100; i++ {
- pr, err := p2.Pick(balancer.PickInfo{})
- if i < 10 && err != nil {
- t.Errorf("The first 10%% picks should be non-drops, got error %v", err)
- } else if i > 10 && err == nil {
- t.Errorf("The next 90%% picks should be drops, got error <nil>")
- }
- dones = append(dones, func() {
- if pr.Done != nil {
- pr.Done(balancer.DoneInfo{})
- }
- })
- }
-
- for _, done := range dones {
- done()
- }
- dones = []func(){}
-
- // Pick without drops.
- for i := 0; i < 10; i++ {
- pr, err := p2.Pick(balancer.PickInfo{})
- if err != nil {
- t.Errorf("The next 10%% picks should be non-drops, got error %v", err)
- }
- dones = append(dones, func() {
- if pr.Done != nil {
- pr.Done(balancer.DoneInfo{})
- }
- })
- }
-
- // Without this, future tests with the same service name will fail.
- for _, done := range dones {
- done()
- }
-}
-
-func init() {
- balancer.Register(&testInlineUpdateBalancerBuilder{})
-}
-
-// A test balancer that updates balancer.State inline when handling ClientConn
-// state.
-type testInlineUpdateBalancerBuilder struct{}
-
-func (*testInlineUpdateBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
- return &testInlineUpdateBalancer{cc: cc}
-}
-
-func (*testInlineUpdateBalancerBuilder) Name() string {
- return "test-inline-update-balancer"
-}
-
-type testInlineUpdateBalancer struct {
- cc balancer.ClientConn
-}
-
-func (tb *testInlineUpdateBalancer) ResolverError(error) {
- panic("not implemented")
-}
-
-func (tb *testInlineUpdateBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) {
-}
-
-var errTestInlineStateUpdate = fmt.Errorf("don't like addresses, empty or not")
-
-func (tb *testInlineUpdateBalancer) UpdateClientConnState(balancer.ClientConnState) error {
- tb.cc.UpdateState(balancer.State{
- ConnectivityState: connectivity.Ready,
- Picker: &testutils.TestConstPicker{Err: errTestInlineStateUpdate},
- })
- return nil
-}
-
-func (*testInlineUpdateBalancer) Close() {
-}
-
-// When the child policy updates the picker inline in a handleClientUpdate call
-// (e.g., roundrobin handling empty addresses), there could be a deadlock caused
-// by acquiring a locked mutex.
-func (s) TestEDS_ChildPolicyUpdatePickerInline(t *testing.T) {
- cc := testutils.NewTestClientConn(t)
- edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil)
- edsb.enqueueChildBalancerStateUpdate = func(p priorityType, state balancer.State) {
- // For this test, enqueue needs to happen asynchronously (like in the
- // real implementation).
- go edsb.updateState(p, state) - } - - edsb.handleChildPolicy("test-inline-update-balancer", nil) - - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - p0 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - _, err := p0.Pick(balancer.PickInfo{}) - if err != errTestInlineStateUpdate { - t.Fatalf("picker.Pick, got err %q, want err %q", err, errTestInlineStateUpdate) - } - } -} - -func (s) TestDropPicker(t *testing.T) { - const pickCount = 12 - var constPicker = &testutils.TestConstPicker{ - SC: testutils.TestSubConns[0], - } - - tests := []struct { - name string - drops []*dropper - }{ - { - name: "no drop", - drops: nil, - }, - { - name: "one drop", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - { - name: "two drops", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 3}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - { - name: "three drops", - drops: []*dropper{ - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 3}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 4}), - newDropper(xdsclient.OverloadDropConfig{Numerator: 1, Denominator: 2}), - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - p := newDropPicker(constPicker, tt.drops, nil, nil, defaultServiceRequestCountMax) - - // scCount is the number of sc's returned by pick. The opposite of - // drop-count. - var ( - scCount int - wantCount = pickCount - ) - for _, dp := range tt.drops { - wantCount = wantCount * int(dp.c.Denominator-dp.c.Numerator) / int(dp.c.Denominator) - } - - for i := 0; i < pickCount; i++ { - _, err := p.Pick(balancer.PickInfo{}) - if err == nil { - scCount++ - } - } - - if scCount != (wantCount) { - t.Errorf("drops: %+v, scCount %v, wantCount %v", tt.drops, scCount, wantCount) - } - }) - } -} - -func (s) TestEDS_LoadReport(t *testing.T) { - origCircuitBreakingSupport := env.CircuitBreakingSupport - env.CircuitBreakingSupport = true - defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }() - - // We create an xdsClientWrapper with a dummy xdsClientInterface which only - // implements the LoadStore() method to return the underlying load.Store to - // be used. - loadStore := load.NewStore() - lsWrapper := loadstore.NewWrapper() - lsWrapper.UpdateClusterAndService(testClusterNames[0], "") - lsWrapper.UpdateLoadStore(loadStore) - - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, lsWrapper, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - const ( - testServiceName = "test-service" - cbMaxRequests = 20 - ) - var maxRequestsTemp uint32 = cbMaxRequests - edsb.updateServiceRequestsConfig(testServiceName, &maxRequestsTemp) - defer client.ClearCounterForTesting(testServiceName) - - backendToBalancerID := make(map[balancer.SubConn]xdsinternal.LocalityID) - - const testDropCategory = "test-drop" - // Two localities, each with one backend. 
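// In TestDropPicker above, the expected pass-through count composes
// multiplicatively over the droppers: a dropper with config N/D lets a
// (D-N)/D fraction of picks through. With 12 picks and drops 1/3, 1/4, 1/2:
//
//	12 * (2.0/3) = 8 // after the 1/3 dropper
//	8 * (3.0/4) = 6  // after the 1/4 dropper
//	6 * (1.0/2) = 3  // after the 1/2 dropper; 3 picks reach the inner picker
//
// The dropper type itself is defined elsewhere in this package; a sketch of a
// deterministic counter-based implementation consistent with that arithmetic
// (an assumption, not the actual code; sync/atomic assumed imported):
//
//	type dropper struct {
//		c xdsclient.OverloadDropConfig
//		i uint64
//	}
//
//	// drop returns true for Numerator out of every Denominator calls.
//	func (d *dropper) drop() bool {
//		i := atomic.AddUint64(&d.i, 1) - 1
//		return i%uint64(d.c.Denominator) < uint64(d.c.Numerator)
//	}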
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], map[string]uint32{testDropCategory: 50}) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - locality1 := xdsinternal.LocalityID{SubZone: testSubZones[0]} - backendToBalancerID[sc1] = locality1 - - // Add the second locality later to make sure sc2 belongs to the second - // locality. Otherwise the test is flaky because a map is used in EDS to - // keep localities. - clab1.AddLocality(testSubZones[1], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - locality2 := xdsinternal.LocalityID{SubZone: testSubZones[1]} - backendToBalancerID[sc2] = locality2 - - // Test roundrobin with two subconns. - p1 := <-cc.NewPickerCh - // We expect the picks to be split evenly between the localities since they are - // of equal weight. And since we only mark the picks routed to sc2 as done, - // the picks on sc1 should show up as inProgress. - locality1JSON, _ := locality1.ToString() - locality2JSON, _ := locality2.ToString() - const ( - rpcCount = 100 - // 50% will be dropped with category testDropCategory. - dropWithCategory = rpcCount / 2 - // In the remaining RPCs, only cbMaxRequests are allowed by circuit - // breaking. Others will be dropped by CB. - dropWithCB = rpcCount - dropWithCategory - cbMaxRequests - - rpcInProgress = cbMaxRequests / 2 // 50% of RPCs will never be done. - rpcSucceeded = cbMaxRequests / 2 // 50% of RPCs will succeed. - ) - wantStoreData := []*load.Data{{ - Cluster: testClusterNames[0], - Service: "", - LocalityStats: map[string]load.LocalityData{ - locality1JSON: {RequestStats: load.RequestData{InProgress: rpcInProgress}}, - locality2JSON: {RequestStats: load.RequestData{Succeeded: rpcSucceeded}}, - }, - TotalDrops: dropWithCategory + dropWithCB, - Drops: map[string]uint64{ - testDropCategory: dropWithCategory, - }, - }} - - var rpcsToBeDone []balancer.PickResult - // Run the picks, but only picks routed to sc2 will be done later. - for i := 0; i < rpcCount; i++ { - scst, _ := p1.Pick(balancer.PickInfo{}) - if scst.Done != nil && scst.SubConn != sc1 { - rpcsToBeDone = append(rpcsToBeDone, scst) - } - } - // Call done on those sc2 picks. - for _, scst := range rpcsToBeDone { - scst.Done(balancer.DoneInfo{}) - } - - gotStoreData := loadStore.Stats(testClusterNames[0:1]) - if diff := cmp.Diff(wantStoreData, gotStoreData, cmpopts.EquateEmpty(), cmpopts.IgnoreFields(load.Data{}, "ReportInterval")); diff != "" { - t.Errorf("store.stats() returned unexpected diff (-want +got):\n%s", diff) - } -} - -// TestEDS_LoadReportDisabled covers the case that LRS is disabled. It makes -// sure the EDS implementation isn't broken (doesn't panic). -func (s) TestEDS_LoadReportDisabled(t *testing.T) { - lsWrapper := loadstore.NewWrapper() - lsWrapper.UpdateClusterAndService(testClusterNames[0], "") - // Not calling lsWrapper.updateLoadStore(loadStore) because LRS is disabled. - - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, lsWrapper, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - // One locality, with one backend.
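The expected load-report numbers above are plain counting, which is easy to verify in isolation. A sketch of the same arithmetic with the test's constants (100 RPCs, a 50% category drop, and a circuit-breaking cap of 20):

```go
package main

import "fmt"

func main() {
	const (
		rpcCount      = 100
		cbMaxRequests = 20
	)
	dropWithCategory := rpcCount / 2                          // 50 dropped by category "test-drop".
	dropWithCB := rpcCount - dropWithCategory - cbMaxRequests // 30 dropped by circuit breaking.
	// The 20 admitted RPCs split evenly across the two equal-weight
	// localities; only sc2's half is marked Done in the test.
	rpcInProgress := cbMaxRequests / 2 // 10 reported in progress (sc1).
	rpcSucceeded := cbMaxRequests / 2  // 10 reported succeeded (sc2).
	fmt.Println(dropWithCategory+dropWithCB, rpcInProgress, rpcSucceeded) // 80 10 10
}
```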
- clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Test picks with one subconn. - p1 := <-cc.NewPickerCh - // We call picks to make sure they don't panic. - for i := 0; i < 10; i++ { - p1.Pick(balancer.PickInfo{}) - } -} - -// TestEDS_ClusterNameInAddressAttributes covers the case that cluster name is -// attached to the subconn address attributes. -func (s) TestEDS_ClusterNameInAddressAttributes(t *testing.T) { - cc := testutils.NewTestClientConn(t) - edsb := newEDSBalancerImpl(cc, balancer.BuildOptions{}, nil, nil, nil) - edsb.enqueueChildBalancerStateUpdate = edsb.updateState - - const clusterName1 = "cluster-name-1" - edsb.updateClusterName(clusterName1) - - // One locality with one backend. - clab1 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab1.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[:1], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab1.Build())) - - addrs1 := <-cc.NewSubConnAddrsCh - if got, want := addrs1[0].Addr, testEndpointAddrs[0]; got != want { - t.Fatalf("sc is created with addr %v, want %v", got, want) - } - cn, ok := internal.GetXDSHandshakeClusterName(addrs1[0].Attributes) - if !ok || cn != clusterName1 { - t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, clusterName1) - } - - sc1 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc1, connectivity.Connecting) - edsb.handleSubConnStateChange(sc1, connectivity.Ready) - - // Pick with only the first backend. - p1 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - } - - // Change cluster name. - const clusterName2 = "cluster-name-2" - edsb.updateClusterName(clusterName2) - - // Change backend. - clab2 := testutils.NewClusterLoadAssignmentBuilder(testClusterNames[0], nil) - clab2.AddLocality(testSubZones[0], 1, 0, testEndpointAddrs[1:2], nil) - edsb.handleEDSResponse(parseEDSRespProtoForTesting(clab2.Build())) - - addrs2 := <-cc.NewSubConnAddrsCh - if got, want := addrs2[0].Addr, testEndpointAddrs[1]; got != want { - t.Fatalf("sc is created with addr %v, want %v", got, want) - } - // New addresses should have the new cluster name. - cn2, ok := internal.GetXDSHandshakeClusterName(addrs2[0].Attributes) - if !ok || cn2 != clusterName2 { - t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, clusterName2) - } - - sc2 := <-cc.NewSubConnCh - edsb.handleSubConnStateChange(sc2, connectivity.Connecting) - edsb.handleSubConnStateChange(sc2, connectivity.Ready) - - // Test pick with the new subconn.
- p2 := <-cc.NewPickerCh - for i := 0; i < 5; i++ { - gotSCSt, _ := p2.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc2, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc2) - } - } -} diff --git a/xds/internal/balancer/edsbalancer/eds_test.go b/xds/internal/balancer/edsbalancer/eds_test.go deleted file mode 100644 index 3fe66098973c..000000000000 --- a/xds/internal/balancer/edsbalancer/eds_test.go +++ /dev/null @@ -1,933 +0,0 @@ -// +build go1.12 - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package edsbalancer - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "reflect" - "testing" - "time" - - "github.com/golang/protobuf/jsonpb" - wrapperspb "github.com/golang/protobuf/ptypes/wrappers" - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/internal/grpctest" - scpb "google.golang.org/grpc/internal/proto/grpc_service_config" - "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" - - _ "google.golang.org/grpc/xds/internal/client/v2" // V2 client registration. -) - -const ( - defaultTestTimeout = 1 * time.Second - defaultTestShortTimeout = 10 * time.Millisecond - testServiceName = "test/foo" - testEDSClusterName = "test/service/eds" -) - -var ( - // A non-empty endpoints update which is expected to be accepted by the EDS - // LB policy. - defaultEndpointsUpdate = xdsclient.EndpointsUpdate{ - Localities: []xdsclient.Locality{ - { - Endpoints: []xdsclient.Endpoint{{Address: "endpoint1"}}, - ID: internal.LocalityID{Zone: "zone"}, - Priority: 1, - Weight: 100, - }, - }, - } -) - -func init() { - balancer.Register(&edsBalancerBuilder{}) -} - -func subConnFromPicker(p balancer.Picker) func() balancer.SubConn { - return func() balancer.SubConn { - scst, _ := p.Pick(balancer.PickInfo{}) - return scst.SubConn - } -} - -type s struct { - grpctest.Tester -} - -func Test(t *testing.T) { - grpctest.RunSubTests(t, s{}) -} - -const testBalancerNameFooBar = "foo.bar" - -func newNoopTestClientConn() *noopTestClientConn { - return &noopTestClientConn{} -} - -// noopTestClientConn is used in EDS balancer config update tests that only -// cover the config update handling, but not SubConn/load-balancing. 
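The noopTestClientConn declared next uses a standard Go testing idiom: embed the interface so the struct satisfies it wholesale, then override only the methods the code under test will call; anything else panics with a nil-pointer dereference, which is acceptable for config-only tests. A generic sketch of the pattern, with a toy interface rather than the real gRPC one:

```go
package main

import "fmt"

type ClientConn interface {
	NewSubConn(addrs []string) (string, error)
	RemoveSubConn(id string)
	Target() string
}

// noopConn embeds the interface value (nil here), so it satisfies
// ClientConn without writing every method. RemoveSubConn is not
// overridden: calling it would panic, flagging an unexpected code path.
type noopConn struct {
	ClientConn
}

func (noopConn) NewSubConn([]string) (string, error) { return "", nil }

func (noopConn) Target() string { return "test/foo" }

func main() {
	var cc ClientConn = noopConn{}
	fmt.Println(cc.Target())
}
```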
-type noopTestClientConn struct { - balancer.ClientConn -} - -func (t *noopTestClientConn) NewSubConn([]resolver.Address, balancer.NewSubConnOptions) (balancer.SubConn, error) { - return nil, nil -} - -func (noopTestClientConn) Target() string { return testServiceName } - -type scStateChange struct { - sc balancer.SubConn - state connectivity.State -} - -type fakeEDSBalancer struct { - cc balancer.ClientConn - childPolicy *testutils.Channel - subconnStateChange *testutils.Channel - edsUpdate *testutils.Channel - serviceName *testutils.Channel - serviceRequestMax *testutils.Channel - clusterName *testutils.Channel -} - -func (f *fakeEDSBalancer) handleSubConnStateChange(sc balancer.SubConn, state connectivity.State) { - f.subconnStateChange.Send(&scStateChange{sc: sc, state: state}) -} - -func (f *fakeEDSBalancer) handleChildPolicy(name string, config json.RawMessage) { - f.childPolicy.Send(&loadBalancingConfig{Name: name, Config: config}) -} - -func (f *fakeEDSBalancer) handleEDSResponse(edsResp xdsclient.EndpointsUpdate) { - f.edsUpdate.Send(edsResp) -} - -func (f *fakeEDSBalancer) updateState(priority priorityType, s balancer.State) {} - -func (f *fakeEDSBalancer) updateServiceRequestsConfig(serviceName string, max *uint32) { - f.serviceName.Send(serviceName) - f.serviceRequestMax.Send(max) -} - -func (f *fakeEDSBalancer) updateClusterName(name string) { - f.clusterName.Send(name) -} - -func (f *fakeEDSBalancer) close() {} - -func (f *fakeEDSBalancer) waitForChildPolicy(ctx context.Context, wantPolicy *loadBalancingConfig) error { - val, err := f.childPolicy.Receive(ctx) - if err != nil { - return err - } - gotPolicy := val.(*loadBalancingConfig) - if !cmp.Equal(gotPolicy, wantPolicy) { - return fmt.Errorf("got childPolicy %v, want %v", gotPolicy, wantPolicy) - } - return nil -} - -func (f *fakeEDSBalancer) waitForSubConnStateChange(ctx context.Context, wantState *scStateChange) error { - val, err := f.subconnStateChange.Receive(ctx) - if err != nil { - return err - } - gotState := val.(*scStateChange) - if !cmp.Equal(gotState, wantState, cmp.AllowUnexported(scStateChange{})) { - return fmt.Errorf("got subconnStateChange %v, want %v", gotState, wantState) - } - return nil -} - -func (f *fakeEDSBalancer) waitForEDSResponse(ctx context.Context, wantUpdate xdsclient.EndpointsUpdate) error { - val, err := f.edsUpdate.Receive(ctx) - if err != nil { - return err - } - gotUpdate := val.(xdsclient.EndpointsUpdate) - if !reflect.DeepEqual(gotUpdate, wantUpdate) { - return fmt.Errorf("got edsUpdate %+v, want %+v", gotUpdate, wantUpdate) - } - return nil -} - -func (f *fakeEDSBalancer) waitForCounterUpdate(ctx context.Context, wantServiceName string) error { - val, err := f.serviceName.Receive(ctx) - if err != nil { - return err - } - gotServiceName := val.(string) - if gotServiceName != wantServiceName { - return fmt.Errorf("got serviceName %v, want %v", gotServiceName, wantServiceName) - } - return nil -} - -func (f *fakeEDSBalancer) waitForCountMaxUpdate(ctx context.Context, want *uint32) error { - val, err := f.serviceRequestMax.Receive(ctx) - if err != nil { - return err - } - got := val.(*uint32) - - if got == nil && want == nil { - return nil - } - if got != nil && want != nil { - if *got != *want { - return fmt.Errorf("got countMax %v, want %v", *got, *want) - } - return nil - } - return fmt.Errorf("got countMax %+v, want %+v", got, want) -} - -func (f *fakeEDSBalancer) waitForClusterNameUpdate(ctx context.Context, wantClusterName string) error { - val, err := f.clusterName.Receive(ctx) - 
if err != nil { - return err - } - gotServiceName := val.(string) - if gotServiceName != wantClusterName { - return fmt.Errorf("got clusterName %v, want %v", gotServiceName, wantClusterName) - } - return nil -} - -func newFakeEDSBalancer(cc balancer.ClientConn) edsBalancerImplInterface { - return &fakeEDSBalancer{ - cc: cc, - childPolicy: testutils.NewChannelWithSize(10), - subconnStateChange: testutils.NewChannelWithSize(10), - edsUpdate: testutils.NewChannelWithSize(10), - serviceName: testutils.NewChannelWithSize(10), - serviceRequestMax: testutils.NewChannelWithSize(10), - clusterName: testutils.NewChannelWithSize(10), - } -} - -type fakeSubConn struct{} - -func (*fakeSubConn) UpdateAddresses([]resolver.Address) { panic("implement me") } -func (*fakeSubConn) Connect() { panic("implement me") } - -// waitForNewEDSLB makes sure that a new edsLB is created by the top-level -// edsBalancer. -func waitForNewEDSLB(ctx context.Context, ch *testutils.Channel) (*fakeEDSBalancer, error) { - val, err := ch.Receive(ctx) - if err != nil { - return nil, fmt.Errorf("error when waiting for a new edsLB: %v", err) - } - return val.(*fakeEDSBalancer), nil -} - -// setup overrides the functions which are used to create the xdsClient and the -// edsLB, creates fake versions of them and makes them available on the provided -// channels. The returned cancel function should be called by the test for -// cleanup. -func setup(edsLBCh *testutils.Channel) (*fakeclient.Client, func()) { - xdsC := fakeclient.NewClientWithName(testBalancerNameFooBar) - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - - origNewEDSBalancer := newEDSBalancer - newEDSBalancer = func(cc balancer.ClientConn, _ balancer.BuildOptions, _ func(priorityType, balancer.State), _ load.PerClusterReporter, _ *grpclog.PrefixLogger) edsBalancerImplInterface { - edsLB := newFakeEDSBalancer(cc) - defer func() { edsLBCh.Send(edsLB) }() - return edsLB - } - return xdsC, func() { - newEDSBalancer = origNewEDSBalancer - newXDSClient = oldNewXDSClient - } -} - -const ( - fakeBalancerA = "fake_balancer_A" - fakeBalancerB = "fake_balancer_B" -) - -// Install two fake balancers for service config update tests. -// -// ParseConfig only accepts the json if the balancer specified is registered. -func init() { - balancer.Register(&fakeBalancerBuilder{name: fakeBalancerA}) - balancer.Register(&fakeBalancerBuilder{name: fakeBalancerB}) -} - -type fakeBalancerBuilder struct { - name string -} - -func (b *fakeBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - return &fakeBalancer{cc: cc} -} - -func (b *fakeBalancerBuilder) Name() string { - return b.name -} - -type fakeBalancer struct { - cc balancer.ClientConn -} - -func (b *fakeBalancer) ResolverError(error) { - panic("implement me") -} - -func (b *fakeBalancer) UpdateClientConnState(balancer.ClientConnState) error { - panic("implement me") -} - -func (b *fakeBalancer) UpdateSubConnState(balancer.SubConn, balancer.SubConnState) { - panic("implement me") -} - -func (b *fakeBalancer) Close() {} - -// TestConfigChildPolicyUpdate verifies scenarios where the childPolicy -// section of the lbConfig is updated. -// -// The test does the following: -// * Builds a new EDS balancer. -// * Pushes a new ClientConnState with a childPolicy set to fakeBalancerA. -// Verifies that an EDS watch is registered. It then pushes a new edsUpdate -// through the fakexds client.
Verifies that a new edsLB is created and it -// receives the expected childPolicy. -// * Pushes a new ClientConnState with a childPolicy set to fakeBalancerB. -// Verifies that the existing edsLB receives the new child policy. -func (s) TestConfigChildPolicyUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - lbCfgA := &loadBalancingConfig{ - Name: fakeBalancerA, - Config: json.RawMessage("{}"), - } - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ChildPolicy: lbCfgA, - ClusterName: testEDSClusterName, - EDSServiceName: testServiceName, - }, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(defaultEndpointsUpdate, nil) - if err := edsLB.waitForChildPolicy(ctx, lbCfgA); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCounterUpdate(ctx, testEDSClusterName); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCountMaxUpdate(ctx, nil); err != nil { - t.Fatal(err) - } - - var testCountMax uint32 = 100 - lbCfgB := &loadBalancingConfig{ - Name: fakeBalancerB, - Config: json.RawMessage("{}"), - } - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ChildPolicy: lbCfgB, - ClusterName: testEDSClusterName, - EDSServiceName: testServiceName, - MaxConcurrentRequests: &testCountMax, - }, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - if err := edsLB.waitForChildPolicy(ctx, lbCfgB); err != nil { - t.Fatal(err) - } - if err := edsLB.waitForCounterUpdate(ctx, testEDSClusterName); err != nil { - // Counter is updated even though the service name didn't change. The - // eds_impl will compare the service names, and skip if it didn't change. - t.Fatal(err) - } - if err := edsLB.waitForCountMaxUpdate(ctx, &testCountMax); err != nil { - t.Fatal(err) - } -} - -// TestSubConnStateChange verifies that the top-level edsBalancer passes on -// the subConnStateChange to the appropriate child balancer.
-func (s) TestSubConnStateChange(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatalf("edsB.UpdateClientConnState() failed: %v", err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(defaultEndpointsUpdate, nil) - - fsc := &fakeSubConn{} - state := connectivity.Ready - edsB.UpdateSubConnState(fsc, balancer.SubConnState{ConnectivityState: state}) - if err := edsLB.waitForSubConnStateChange(ctx, &scStateChange{sc: fsc, state: state}); err != nil { - t.Fatal(err) - } -} - -// TestErrorFromXDSClientUpdate verifies that an error from xdsClient update is -// handled correctly. -// -// If it's resource-not-found, watch will NOT be canceled, the EDS impl will -// receive an empty EDS update, and new RPCs will fail. -// -// If it's connection error, nothing will happen. This will need to change to -// handle fallback. -func (s) TestErrorFromXDSClientUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatal(err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil) - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err) - } - - connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, connectionErr) - - sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := edsLB.waitForEDSResponse(sCtx, xdsclient.EndpointsUpdate{}); err != context.DeadlineExceeded { - t.Fatal(err) - } - - resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "edsBalancer resource not found error") - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, resourceErr) - // Even if 
error is resource not found, the watch shouldn't be canceled, because - // it indicates that an EDS resource was removed (and the xds client actually never sends this - // error, but we still handle it). - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("eds impl expecting empty update, got %v", err) - } - - // An update with the same service name should not trigger a new watch. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatal(err) - } - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if _, err := xdsC.WaitForWatchEDS(sCtx); err != context.DeadlineExceeded { - t.Fatal("got unexpected new EDS watch") - } -} - -// TestErrorFromResolver verifies that resolver errors are handled correctly. -// -// If it's resource-not-found, watch will be canceled, the EDS impl will receive -// an empty EDS update, and new RPCs will fail. -// -// If it's connection error, nothing will happen. This will need to change to -// handle fallback. -func (s) TestErrorFromResolver(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsLB, err := waitForNewEDSLB(ctx, edsLBCh) - if err != nil { - t.Fatal(err) - } - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatal(err) - } - - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - xdsC.InvokeWatchEDSCallback(xdsclient.EndpointsUpdate{}, nil) - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err) - } - - connectionErr := xdsclient.NewErrorf(xdsclient.ErrorTypeConnection, "connection error") - edsB.ResolverError(connectionErr) - - sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := xdsC.WaitForCancelEDSWatch(sCtx); err != context.DeadlineExceeded { - t.Fatal("watch was canceled, want not canceled (timeout error)") - } - - sCtx, sCancel = context.WithTimeout(context.Background(), defaultTestShortTimeout) - defer sCancel() - if err := edsLB.waitForEDSResponse(sCtx, xdsclient.EndpointsUpdate{}); err != context.DeadlineExceeded { - t.Fatal("eds impl got EDS resp, want timeout error") - } - - resourceErr := xdsclient.NewErrorf(xdsclient.ErrorTypeResourceNotFound, "edsBalancer resource not found error") - edsB.ResolverError(resourceErr) - if err := xdsC.WaitForCancelEDSWatch(ctx); err != nil { - t.Fatalf("want watch to be canceled, waitForCancel failed: %v", err) - } - if err := edsLB.waitForEDSResponse(ctx, xdsclient.EndpointsUpdate{}); err != nil { - t.Fatalf("EDS impl got unexpected EDS response: %v", err)
- } - - // An update with the same service name should trigger a new watch, because - // the previous watch was canceled. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: testServiceName}, - }); err != nil { - t.Fatal(err) - } - if _, err := xdsC.WaitForWatchEDS(ctx); err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } -} - -// Given a list of resource names, verifies that EDS requests for the same are -// sent by the EDS balancer, through the fake xDS client. -func verifyExpectedRequests(ctx context.Context, fc *fakeclient.Client, resourceNames ...string) error { - for _, name := range resourceNames { - if name == "" { - // ResourceName empty string indicates a cancel. - if err := fc.WaitForCancelEDSWatch(ctx); err != nil { - return fmt.Errorf("timed out when expecting resource %q", name) - } - continue - } - - resName, err := fc.WaitForWatchEDS(ctx) - if err != nil { - return fmt.Errorf("timed out when expecting resource %q, %p", name, fc) - } - if resName != name { - return fmt.Errorf("got EDS request for resource %q, expected: %q", resName, name) - } - } - return nil -} - -// TestClientWatchEDS verifies that the xdsClient inside the top-level EDS LB -// policy registers an EDS watch for the expected resource upon receiving an update -// from gRPC. -func (s) TestClientWatchEDS(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - // If eds service name is not set, should watch for cluster name. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ClusterName: "cluster-1"}, - }); err != nil { - t.Fatal(err) - } - if err := verifyExpectedRequests(ctx, xdsC, "cluster-1"); err != nil { - t.Fatal(err) - } - - // Update with a non-empty edsServiceName should trigger an EDS watch for - // the same. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: "foobar-1"}, - }); err != nil { - t.Fatal(err) - } - if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-1"); err != nil { - t.Fatal(err) - } - - // Also test the case where the edsServiceName changes from one non-empty - // name to another, and make sure a new watch is registered. The previously - // registered watch will be canceled, which will result in an EDS request - // with no resource names being sent to the server. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{EDSServiceName: "foobar-2"}, - }); err != nil { - t.Fatal(err) - } - if err := verifyExpectedRequests(ctx, xdsC, "", "foobar-2"); err != nil { - t.Fatal(err) - } -} - -// TestCounterUpdate verifies that the counter update is triggered with the -// service name from an update's config.
-func (s) TestCounterUpdate(t *testing.T) { - edsLBCh := testutils.NewChannel() - _, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - var testCountMax uint32 = 100 - // Update should trigger counter update with provided service name. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ClusterName: "foobar-1", - MaxConcurrentRequests: &testCountMax, - }, - }); err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - edsI := edsB.(*edsBalancer).edsImpl.(*fakeEDSBalancer) - if err := edsI.waitForCounterUpdate(ctx, "foobar-1"); err != nil { - t.Fatal(err) - } - if err := edsI.waitForCountMaxUpdate(ctx, &testCountMax); err != nil { - t.Fatal(err) - } -} - -// TestClusterNameUpdateInAddressAttributes verifies that cluster name update in -// edsImpl is triggered with the update from a new service config. -func (s) TestClusterNameUpdateInAddressAttributes(t *testing.T) { - edsLBCh := testutils.NewChannel() - xdsC, cleanup := setup(edsLBCh) - defer cleanup() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{Target: resolver.Target{Endpoint: testServiceName}}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - // Update should trigger counter update with provided service name. - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ClusterName: "foobar-1", - }, - }); err != nil { - t.Fatal(err) - } - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - gotCluster, err := xdsC.WaitForWatchEDS(ctx) - if err != nil || gotCluster != "foobar-1" { - t.Fatalf("unexpected EDS watch: %v, %v", gotCluster, err) - } - edsI := edsB.(*edsBalancer).edsImpl.(*fakeEDSBalancer) - if err := edsI.waitForClusterNameUpdate(ctx, "foobar-1"); err != nil { - t.Fatal(err) - } - - // Update should trigger counter update with provided service name. 
- if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - ClusterName: "foobar-2", - }, - }); err != nil { - t.Fatal(err) - } - if err := xdsC.WaitForCancelEDSWatch(ctx); err != nil { - t.Fatalf("failed to wait for EDS cancel: %v", err) - } - gotCluster2, err := xdsC.WaitForWatchEDS(ctx) - if err != nil || gotCluster2 != "foobar-2" { - t.Fatalf("unexpected EDS watch: %v, %v", gotCluster2, err) - } - if err := edsI.waitForClusterNameUpdate(ctx, "foobar-2"); err != nil { - t.Fatal(err) - } -} - -func (s) TestBalancerConfigParsing(t *testing.T) { - const testEDSName = "eds.service" - var testLRSName = "lrs.server" - b := bytes.NewBuffer(nil) - if err := (&jsonpb.Marshaler{}).Marshal(b, &scpb.XdsConfig{ - ChildPolicy: []*scpb.LoadBalancingConfig{ - {Policy: &scpb.LoadBalancingConfig_Xds{}}, - {Policy: &scpb.LoadBalancingConfig_RoundRobin{ - RoundRobin: &scpb.RoundRobinConfig{}, - }}, - }, - FallbackPolicy: []*scpb.LoadBalancingConfig{ - {Policy: &scpb.LoadBalancingConfig_Xds{}}, - {Policy: &scpb.LoadBalancingConfig_PickFirst{ - PickFirst: &scpb.PickFirstConfig{}, - }}, - }, - EdsServiceName: testEDSName, - LrsLoadReportingServerName: &wrapperspb.StringValue{Value: testLRSName}, - }); err != nil { - t.Fatalf("%v", err) - } - - var testMaxConcurrentRequests uint32 = 123 - tests := []struct { - name string - js json.RawMessage - want serviceconfig.LoadBalancingConfig - wantErr bool - }{ - { - name: "bad json", - js: json.RawMessage(`i am not JSON`), - wantErr: true, - }, - { - name: "empty", - js: json.RawMessage(`{}`), - want: &EDSConfig{}, - }, - { - name: "jsonpb-generated", - js: b.Bytes(), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "round_robin", - Config: json.RawMessage("{}"), - }, - FallBackPolicy: &loadBalancingConfig{ - Name: "pick_first", - Config: json.RawMessage("{}"), - }, - EDSServiceName: testEDSName, - LrsLoadReportingServerName: &testLRSName, - }, - }, - { - // json with random balancers, and the first is not registered. - name: "manually-generated", - js: json.RawMessage(` -{ - "childPolicy": [ - {"fake_balancer_C": {}}, - {"fake_balancer_A": {}}, - {"fake_balancer_B": {}} - ], - "fallbackPolicy": [ - {"fake_balancer_C": {}}, - {"fake_balancer_B": {}}, - {"fake_balancer_A": {}} - ], - "edsServiceName": "eds.service", - "maxConcurrentRequests": 123, - "lrsLoadReportingServerName": "lrs.server" -}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "fake_balancer_A", - Config: json.RawMessage("{}"), - }, - FallBackPolicy: &loadBalancingConfig{ - Name: "fake_balancer_B", - Config: json.RawMessage("{}"), - }, - EDSServiceName: testEDSName, - MaxConcurrentRequests: &testMaxConcurrentRequests, - LrsLoadReportingServerName: &testLRSName, - }, - }, - { - // json with no lrs server name, LoadReportingServerName should - // be nil (not an empty string). 
- name: "no-lrs-server-name", - js: json.RawMessage(` -{ - "edsServiceName": "eds.service" -}`), - want: &EDSConfig{ - EDSServiceName: testEDSName, - LrsLoadReportingServerName: nil, - }, - }, - { - name: "good child policy", - js: json.RawMessage(`{"childPolicy":[{"pick_first":{}}]}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "pick_first", - Config: json.RawMessage(`{}`), - }, - }, - }, - { - name: "multiple good child policies", - js: json.RawMessage(`{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`), - want: &EDSConfig{ - ChildPolicy: &loadBalancingConfig{ - Name: "round_robin", - Config: json.RawMessage(`{}`), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - b := &edsBalancerBuilder{} - got, err := b.ParseConfig(tt.js) - if (err != nil) != tt.wantErr { - t.Fatalf("edsBalancerBuilder.ParseConfig() error = %v, wantErr %v", err, tt.wantErr) - } - if tt.wantErr { - return - } - if !cmp.Equal(got, tt.want) { - t.Errorf(cmp.Diff(got, tt.want)) - } - }) - } -} - -func (s) TestEqualStringPointers(t *testing.T) { - var ( - ta1 = "test-a" - ta2 = "test-a" - tb = "test-b" - ) - tests := []struct { - name string - a *string - b *string - want bool - }{ - {"both-nil", nil, nil, true}, - {"a-non-nil", &ta1, nil, false}, - {"b-non-nil", nil, &tb, false}, - {"equal", &ta1, &ta2, true}, - {"different", &ta1, &tb, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := equalStringPointers(tt.a, tt.b); got != tt.want { - t.Errorf("equalStringPointers() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/xds/internal/balancer/edsbalancer/util.go b/xds/internal/balancer/edsbalancer/util.go deleted file mode 100644 index 132950426466..000000000000 --- a/xds/internal/balancer/edsbalancer/util.go +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import ( - "google.golang.org/grpc/internal/wrr" - xdsclient "google.golang.org/grpc/xds/internal/client" -) - -var newRandomWRR = wrr.NewRandom - -type dropper struct { - c xdsclient.OverloadDropConfig - w wrr.WRR -} - -func newDropper(c xdsclient.OverloadDropConfig) *dropper { - w := newRandomWRR() - w.Add(true, int64(c.Numerator)) - w.Add(false, int64(c.Denominator-c.Numerator)) - - return &dropper{ - c: c, - w: w, - } -} - -func (d *dropper) drop() (ret bool) { - return d.w.Next().(bool) -} diff --git a/xds/internal/balancer/edsbalancer/util_test.go b/xds/internal/balancer/edsbalancer/util_test.go deleted file mode 100644 index b94905d49889..000000000000 --- a/xds/internal/balancer/edsbalancer/util_test.go +++ /dev/null @@ -1,90 +0,0 @@ -// +build go1.12 - -/* - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import ( - "testing" - - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/testutils" -) - -func init() { - newRandomWRR = testutils.NewTestWRR -} - -func (s) TestDropper(t *testing.T) { - const repeat = 2 - - type args struct { - numerator uint32 - denominator uint32 - } - tests := []struct { - name string - args args - }{ - { - name: "2_3", - args: args{ - numerator: 2, - denominator: 3, - }, - }, - { - name: "4_8", - args: args{ - numerator: 4, - denominator: 8, - }, - }, - { - name: "7_20", - args: args{ - numerator: 7, - denominator: 20, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - d := newDropper(xdsclient.OverloadDropConfig{ - Category: "", - Numerator: tt.args.numerator, - Denominator: tt.args.denominator, - }) - - var ( - dCount int - wantCount = int(tt.args.numerator) * repeat - loopCount = int(tt.args.denominator) * repeat - ) - for i := 0; i < loopCount; i++ { - if d.drop() { - dCount++ - } - } - - if dCount != (wantCount) { - t.Errorf("with numerator %v, denominator %v repeat %v, got drop count: %v, want %v", - tt.args.numerator, tt.args.denominator, repeat, dCount, wantCount) - } - }) - } -} diff --git a/xds/internal/balancer/edsbalancer/xds_lrs_test.go b/xds/internal/balancer/edsbalancer/xds_lrs_test.go deleted file mode 100644 index 8b7aab657667..000000000000 --- a/xds/internal/balancer/edsbalancer/xds_lrs_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// +build go1.12 - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package edsbalancer - -import ( - "context" - "testing" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" -) - -// TestXDSLoadReporting verifies that the edsBalancer starts the loadReport -// stream when the lbConfig passed to it contains a valid value for the LRS -// server (empty string). 
-func (s) TestXDSLoadReporting(t *testing.T) { - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() - - builder := balancer.Get(edsName) - edsB := builder.Build(newNoopTestClientConn(), balancer.BuildOptions{}) - if edsB == nil { - t.Fatalf("builder.Build(%s) failed and returned nil", edsName) - } - defer edsB.Close() - - if err := edsB.UpdateClientConnState(balancer.ClientConnState{ - BalancerConfig: &EDSConfig{ - EDSServiceName: testEDSClusterName, - LrsLoadReportingServerName: new(string), - }, - }); err != nil { - t.Fatal(err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - gotCluster, err := xdsC.WaitForWatchEDS(ctx) - if err != nil { - t.Fatalf("xdsClient.WatchEndpoints failed with error: %v", err) - } - if gotCluster != testEDSClusterName { - t.Fatalf("xdsClient.WatchEndpoints() called with cluster: %v, want %v", gotCluster, testEDSClusterName) - } - - got, err := xdsC.WaitForReportLoad(ctx) - if err != nil { - t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) - } - if got.Server != "" { - t.Fatalf("xdsClient.ReportLoad called with {%v}: want {\"\"}", got.Server) - } -} diff --git a/xds/internal/balancer/edsbalancer/xds_old.go b/xds/internal/balancer/edsbalancer/xds_old.go deleted file mode 100644 index 6729e6801f15..000000000000 --- a/xds/internal/balancer/edsbalancer/xds_old.go +++ /dev/null @@ -1,46 +0,0 @@ -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package edsbalancer - -import "google.golang.org/grpc/balancer" - -// The old xds balancer implements logic for both CDS and EDS. With the new -// design, CDS is split and moved to a separate balancer, and the xds balancer -// becomes the EDS balancer. -// -// To keep the existing tests working, this file registers the EDS balancer under -// the old xds balancer name. -// -// TODO: delete this file when migration to the new workflow (LDS, RDS, CDS, EDS) is -// done. - -const xdsName = "xds_experimental" - -func init() { - balancer.Register(&xdsBalancerBuilder{}) -} - -// xdsBalancerBuilder registers edsBalancerBuilder (now with name -// "eds_experimental") under the old name "xds_experimental". -type xdsBalancerBuilder struct { - edsBalancerBuilder -} - -func (b *xdsBalancerBuilder) Name() string { - return xdsName -} diff --git a/xds/internal/balancer/loadstore/load_store_wrapper.go b/xds/internal/balancer/loadstore/load_store_wrapper.go index 88fa344118cc..8ce958d71ca8 100644 --- a/xds/internal/balancer/loadstore/load_store_wrapper.go +++ b/xds/internal/balancer/loadstore/load_store_wrapper.go @@ -22,7 +22,7 @@ package loadstore import ( "sync" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // NewWrapper creates a Wrapper.
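The xds_old.go shim above is a compact aliasing pattern: register the same builder under a second name by embedding it and overriding only Name(). A standalone sketch of the idea, using a toy registry rather than the real balancer package:

```go
package main

import "fmt"

type Builder interface{ Name() string }

var registry = map[string]Builder{}

func Register(b Builder) { registry[b.Name()] = b }

type edsBuilder struct{}

func (edsBuilder) Name() string { return "eds_experimental" }

// xdsBuilder reuses edsBuilder's behavior but answers to the old name,
// so lookups under either name reach the same implementation.
type xdsBuilder struct{ edsBuilder }

func (xdsBuilder) Name() string { return "xds_experimental" }

func main() {
	Register(edsBuilder{})
	Register(xdsBuilder{})
	fmt.Println(len(registry)) // 2 names, one behavior.
}
```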
diff --git a/xds/internal/balancer/lrs/balancer.go b/xds/internal/balancer/lrs/balancer.go deleted file mode 100644 index 0642c54ed111..000000000000 --- a/xds/internal/balancer/lrs/balancer.go +++ /dev/null @@ -1,249 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// Package lrs implements load reporting balancer for xds. -package lrs - -import ( - "encoding/json" - "fmt" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/internal/pretty" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal/balancer/loadstore" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/load" -) - -func init() { - balancer.Register(&lrsBB{}) -} - -var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() } - -// Name is the name of the LRS balancer. -const Name = "lrs_experimental" - -type lrsBB struct{} - -func (l *lrsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - b := &lrsBalancer{ - cc: cc, - buildOpts: opts, - } - b.logger = prefixLogger(b) - b.logger.Infof("Created") - - client, err := newXDSClient() - if err != nil { - b.logger.Errorf("failed to create xds-client: %v", err) - return nil - } - b.client = newXDSClientWrapper(client) - - return b -} - -func (l *lrsBB) Name() string { - return Name -} - -func (l *lrsBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { - return parseConfig(c) -} - -type lrsBalancer struct { - cc balancer.ClientConn - buildOpts balancer.BuildOptions - - logger *grpclog.PrefixLogger - client *xdsClientWrapper - - config *LBConfig - lb balancer.Balancer // The sub balancer. -} - -func (b *lrsBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) - newConfig, ok := s.BalancerConfig.(*LBConfig) - if !ok { - return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) - } - - // Update load reporting config or xds client. This needs to be done before - // updating the child policy because we need the loadStore from the updated - // client to be passed to the ccWrapper. - if err := b.client.update(newConfig); err != nil { - return err - } - - // If child policy is a different type, recreate the sub-balancer. 
- if b.config == nil || b.config.ChildPolicy.Name != newConfig.ChildPolicy.Name { - bb := balancer.Get(newConfig.ChildPolicy.Name) - if bb == nil { - return fmt.Errorf("balancer %q not registered", newConfig.ChildPolicy.Name) - } - if b.lb != nil { - b.lb.Close() - } - lidJSON, err := newConfig.Locality.ToString() - if err != nil { - return fmt.Errorf("failed to marshal LocalityID: %#v", newConfig.Locality) - } - ccWrapper := newCCWrapper(b.cc, b.client.loadStore(), lidJSON) - b.lb = bb.Build(ccWrapper, b.buildOpts) - } - b.config = newConfig - - // Addresses and sub-balancer config are sent to sub-balancer. - return b.lb.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: s.ResolverState, - BalancerConfig: b.config.ChildPolicy.Config, - }) -} - -func (b *lrsBalancer) ResolverError(err error) { - if b.lb != nil { - b.lb.ResolverError(err) - } -} - -func (b *lrsBalancer) UpdateSubConnState(sc balancer.SubConn, s balancer.SubConnState) { - if b.lb != nil { - b.lb.UpdateSubConnState(sc, s) - } -} - -func (b *lrsBalancer) Close() { - if b.lb != nil { - b.lb.Close() - b.lb = nil - } - b.client.close() -} - -type ccWrapper struct { - balancer.ClientConn - loadStore load.PerClusterReporter - localityIDJSON string -} - -func newCCWrapper(cc balancer.ClientConn, loadStore load.PerClusterReporter, localityIDJSON string) *ccWrapper { - return &ccWrapper{ - ClientConn: cc, - loadStore: loadStore, - localityIDJSON: localityIDJSON, - } -} - -func (ccw *ccWrapper) UpdateState(s balancer.State) { - s.Picker = newLoadReportPicker(s.Picker, ccw.localityIDJSON, ccw.loadStore) - ccw.ClientConn.UpdateState(s) -} - -// xdsClientInterface contains only the xds_client methods needed by LRS -// balancer. It's defined so we can override xdsclient in tests. -type xdsClientInterface interface { - ReportLoad(server string) (*load.Store, func()) - Close() -} - -type xdsClientWrapper struct { - c xdsClientInterface - cancelLoadReport func() - clusterName string - edsServiceName string - lrsServerName *string - // loadWrapper wraps the load store (loadOriginal) with the clusterName and - // edsServiceName. It's used by children to report loads. - loadWrapper *loadstore.Wrapper -} - -func newXDSClientWrapper(c xdsClientInterface) *xdsClientWrapper { - return &xdsClientWrapper{ - c: c, - loadWrapper: loadstore.NewWrapper(), - } -} - -// update checks the config and xdsclient, and decides whether it needs to -// restart the load reporting stream. -func (w *xdsClientWrapper) update(newConfig *LBConfig) error { - var ( - restartLoadReport bool - updateLoadClusterAndService bool - ) - - // The cluster key for load reporting is built from ClusterName and - // EDSServiceName; if either changes, it needs to be updated. - if w.clusterName != newConfig.ClusterName { - updateLoadClusterAndService = true - w.clusterName = newConfig.ClusterName - } - if w.edsServiceName != newConfig.EDSServiceName { - updateLoadClusterAndService = true - w.edsServiceName = newConfig.EDSServiceName - } - - if updateLoadClusterAndService { - // This updates the clusterName and serviceName that will be reported for the - // loads. The update here is too early; the perfect timing is when the - // picker is updated with the new connection. But from this balancer's point - // of view, it's impossible to tell. - // - // On the other hand, this will almost never happen. Each LRS policy - // shouldn't get updated config. The parent should do a graceful switch when - // the clusterName or serviceName is changed.
- w.loadWrapper.UpdateClusterAndService(w.clusterName, w.edsServiceName) - } - - if w.lrsServerName == nil || *w.lrsServerName != newConfig.LoadReportingServerName { - // LoadReportingServerName is different, so load should be reported to a - // different server; restart. - restartLoadReport = true - w.lrsServerName = &newConfig.LoadReportingServerName - } - - if restartLoadReport { - if w.cancelLoadReport != nil { - w.cancelLoadReport() - w.cancelLoadReport = nil - } - var loadStore *load.Store - if w.c != nil { - loadStore, w.cancelLoadReport = w.c.ReportLoad(*w.lrsServerName) - } - w.loadWrapper.UpdateLoadStore(loadStore) - } - - return nil -} - -func (w *xdsClientWrapper) loadStore() load.PerClusterReporter { - return w.loadWrapper -} - -func (w *xdsClientWrapper) close() { - if w.cancelLoadReport != nil { - w.cancelLoadReport() - w.cancelLoadReport = nil - } - w.c.Close() -} diff --git a/xds/internal/balancer/lrs/balancer_test.go b/xds/internal/balancer/lrs/balancer_test.go deleted file mode 100644 index f91937385a92..000000000000 --- a/xds/internal/balancer/lrs/balancer_test.go +++ /dev/null @@ -1,146 +0,0 @@ -// +build go1.12 - -/* - * - * Copyright 2019 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - "context" - "fmt" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/roundrobin" - "google.golang.org/grpc/connectivity" - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - "google.golang.org/grpc/resolver" - xdsinternal "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/testutils" - "google.golang.org/grpc/xds/internal/testutils/fakeclient" -) - -const defaultTestTimeout = 1 * time.Second - -var ( - testBackendAddrs = []resolver.Address{ - {Addr: "1.1.1.1:1"}, - } - testLocality = &xdsinternal.LocalityID{ - Region: "test-region", - Zone: "test-zone", - SubZone: "test-sub-zone", - } -) - -// TestLoadReporting verifies that the lrs balancer starts the loadReport -// stream when the LBConfig passed to it contains a valid value for the LRS -// server.
-func TestLoadReporting(t *testing.T) { - xdsC := fakeclient.NewClient() - oldNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { return xdsC, nil } - defer func() { newXDSClient = oldNewXDSClient }() - - builder := balancer.Get(Name) - cc := testutils.NewTestClientConn(t) - lrsB := builder.Build(cc, balancer.BuildOptions{}) - defer lrsB.Close() - - if err := lrsB.UpdateClientConnState(balancer.ClientConnState{ - ResolverState: resolver.State{ - Addresses: testBackendAddrs, - }, - BalancerConfig: &LBConfig{ - ClusterName: testClusterName, - EDSServiceName: testServiceName, - LoadReportingServerName: testLRSServerName, - Locality: testLocality, - ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: roundrobin.Name, - }, - }, - }); err != nil { - t.Fatalf("unexpected error from UpdateClientConnState: %v", err) - } - - ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) - defer cancel() - - got, err := xdsC.WaitForReportLoad(ctx) - if err != nil { - t.Fatalf("xdsClient.ReportLoad failed with error: %v", err) - } - if got.Server != testLRSServerName { - t.Fatalf("xdsClient.ReportLoad called with {%q}: want {%q}", got.Server, testLRSServerName) - } - - sc1 := <-cc.NewSubConnCh - lrsB.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Connecting}) - lrsB.UpdateSubConnState(sc1, balancer.SubConnState{ConnectivityState: connectivity.Ready}) - - // Test pick with one backend. - p1 := <-cc.NewPickerCh - const successCount = 5 - for i := 0; i < successCount; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - gotSCSt.Done(balancer.DoneInfo{}) - } - const errorCount = 5 - for i := 0; i < errorCount; i++ { - gotSCSt, _ := p1.Pick(balancer.PickInfo{}) - if !cmp.Equal(gotSCSt.SubConn, sc1, cmp.AllowUnexported(testutils.TestSubConn{})) { - t.Fatalf("picker.Pick, got %v, want SubConn=%v", gotSCSt, sc1) - } - gotSCSt.Done(balancer.DoneInfo{Err: fmt.Errorf("error")}) - } - - // Dump load data from the store and compare with expected counts. - loadStore := xdsC.LoadStore() - if loadStore == nil { - t.Fatal("loadStore is nil in xdsClient") - } - sds := loadStore.Stats([]string{testClusterName}) - if len(sds) == 0 { - t.Fatalf("loads for cluster %v not found in store", testClusterName) - } - sd := sds[0] - if sd.Cluster != testClusterName || sd.Service != testServiceName { - t.Fatalf("got unexpected load for %q, %q, want %q, %q", sd.Cluster, sd.Service, testClusterName, testServiceName) - } - testLocalityJSON, _ := testLocality.ToString() - localityData, ok := sd.LocalityStats[testLocalityJSON] - if !ok { - t.Fatalf("loads for %v not found in store", testLocality) - } - reqStats := localityData.RequestStats - if reqStats.Succeeded != successCount { - t.Errorf("got succeeded %v, want %v", reqStats.Succeeded, successCount) - } - if reqStats.Errored != errorCount { - t.Errorf("got errored %v, want %v", reqStats.Errored, errorCount) - } - if reqStats.InProgress != 0 { - t.Errorf("got inProgress %v, want %v", reqStats.InProgress, 0) - } -} diff --git a/xds/internal/balancer/lrs/config.go b/xds/internal/balancer/lrs/config.go deleted file mode 100644 index e0e30bbb8821..000000000000 --- a/xds/internal/balancer/lrs/config.go +++ /dev/null @@ -1,53 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors.
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - "encoding/json" - "fmt" - - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - "google.golang.org/grpc/serviceconfig" - "google.golang.org/grpc/xds/internal" -) - -// LBConfig is the balancer config for lrs balancer. -type LBConfig struct { - serviceconfig.LoadBalancingConfig `json:"-"` - - ClusterName string `json:"clusterName,omitempty"` - EDSServiceName string `json:"edsServiceName,omitempty"` - LoadReportingServerName string `json:"lrsLoadReportingServerName,omitempty"` - Locality *internal.LocalityID `json:"locality,omitempty"` - ChildPolicy *internalserviceconfig.BalancerConfig `json:"childPolicy,omitempty"` -} - -func parseConfig(c json.RawMessage) (*LBConfig, error) { - var cfg LBConfig - if err := json.Unmarshal(c, &cfg); err != nil { - return nil, err - } - if cfg.ClusterName == "" { - return nil, fmt.Errorf("required ClusterName is not set in %+v", cfg) - } - if cfg.Locality == nil { - return nil, fmt.Errorf("required Locality is not set in %+v", cfg) - } - return &cfg, nil -} diff --git a/xds/internal/balancer/lrs/config_test.go b/xds/internal/balancer/lrs/config_test.go deleted file mode 100644 index eaf902ac535d..000000000000 --- a/xds/internal/balancer/lrs/config_test.go +++ /dev/null @@ -1,114 +0,0 @@ -// +build go1.12 - -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package lrs - -import ( - "testing" - - "github.com/google/go-cmp/cmp" - "google.golang.org/grpc/balancer/roundrobin" - internalserviceconfig "google.golang.org/grpc/internal/serviceconfig" - xdsinternal "google.golang.org/grpc/xds/internal" -) - -const ( - testClusterName = "test-cluster" - testServiceName = "test-eds-service" - testLRSServerName = "test-lrs-name" -) - -func TestParseConfig(t *testing.T) { - tests := []struct { - name string - js string - want *LBConfig - wantErr bool - }{ - { - name: "no cluster name", - js: `{ - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "locality": { - "region": "test-region", - "zone": "test-zone", - "subZone": "test-sub-zone" - }, - "childPolicy":[{"round_robin":{}}] -} - `, - wantErr: true, - }, - { - name: "no locality", - js: `{ - "clusterName": "test-cluster", - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "childPolicy":[{"round_robin":{}}] -} - `, - wantErr: true, - }, - { - name: "good", - js: `{ - "clusterName": "test-cluster", - "edsServiceName": "test-eds-service", - "lrsLoadReportingServerName": "test-lrs-name", - "locality": { - "region": "test-region", - "zone": "test-zone", - "subZone": "test-sub-zone" - }, - "childPolicy":[{"round_robin":{}}] -} - `, - want: &LBConfig{ - ClusterName: testClusterName, - EDSServiceName: testServiceName, - LoadReportingServerName: testLRSServerName, - Locality: &xdsinternal.LocalityID{ - Region: "test-region", - Zone: "test-zone", - SubZone: "test-sub-zone", - }, - ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: roundrobin.Name, - Config: nil, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := parseConfig([]byte(tt.js)) - if (err != nil) != tt.wantErr { - t.Errorf("parseConfig() error = %v, wantErr %v", err, tt.wantErr) - return - } - if diff := cmp.Diff(got, tt.want); diff != "" { - t.Errorf("parseConfig() got = %v, want %v, diff: %s", got, tt.want, diff) - } - }) - } -} diff --git a/xds/internal/balancer/lrs/logging.go b/xds/internal/balancer/lrs/logging.go deleted file mode 100644 index 602dac099597..000000000000 --- a/xds/internal/balancer/lrs/logging.go +++ /dev/null @@ -1,34 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - "fmt" - - "google.golang.org/grpc/grpclog" - internalgrpclog "google.golang.org/grpc/internal/grpclog" -) - -const prefix = "[lrs-lb %p] " - -var logger = grpclog.Component("xds") - -func prefixLogger(p *lrsBalancer) *internalgrpclog.PrefixLogger { - return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(prefix, p)) -} diff --git a/xds/internal/balancer/lrs/picker.go b/xds/internal/balancer/lrs/picker.go deleted file mode 100644 index 1e4ad156e5b7..000000000000 --- a/xds/internal/balancer/lrs/picker.go +++ /dev/null @@ -1,85 +0,0 @@ -/* - * - * Copyright 2020 gRPC authors. 
- * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package lrs - -import ( - orcapb "github.com/cncf/udpa/go/udpa/data/orca/v1" - "google.golang.org/grpc/balancer" -) - -const ( - serverLoadCPUName = "cpu_utilization" - serverLoadMemoryName = "mem_utilization" -) - -// loadReporter wraps the methods from the loadStore that are used here. -type loadReporter interface { - CallStarted(locality string) - CallFinished(locality string, err error) - CallServerLoad(locality, name string, val float64) -} - -type loadReportPicker struct { - p balancer.Picker - - locality string - loadStore loadReporter -} - -func newLoadReportPicker(p balancer.Picker, id string, loadStore loadReporter) *loadReportPicker { - return &loadReportPicker{ - p: p, - locality: id, - loadStore: loadStore, - } -} - -func (lrp *loadReportPicker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { - res, err := lrp.p.Pick(info) - if err != nil { - return res, err - } - - if lrp.loadStore == nil { - return res, err - } - - lrp.loadStore.CallStarted(lrp.locality) - oldDone := res.Done - res.Done = func(info balancer.DoneInfo) { - if oldDone != nil { - oldDone(info) - } - lrp.loadStore.CallFinished(lrp.locality, info.Err) - - load, ok := info.ServerLoad.(*orcapb.OrcaLoadReport) - if !ok { - return - } - lrp.loadStore.CallServerLoad(lrp.locality, serverLoadCPUName, load.CpuUtilization) - lrp.loadStore.CallServerLoad(lrp.locality, serverLoadMemoryName, load.MemUtilization) - for n, d := range load.RequestCost { - lrp.loadStore.CallServerLoad(lrp.locality, n, d) - } - for n, d := range load.Utilization { - lrp.loadStore.CallServerLoad(lrp.locality, n, d) - } - } - return res, err -} diff --git a/xds/internal/balancer/priority/balancer.go b/xds/internal/balancer/priority/balancer.go index d760a58fa4ec..7475145c612b 100644 --- a/xds/internal/balancer/priority/balancer.go +++ b/xds/internal/balancer/priority/balancer.go @@ -44,12 +44,12 @@ import ( const Name = "priority_experimental" func init() { - balancer.Register(priorityBB{}) + balancer.Register(bb{}) } -type priorityBB struct{} +type bb struct{} -func (priorityBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &priorityBalancer{ cc: cc, done: grpcsync.NewEvent(), @@ -66,11 +66,11 @@ func (priorityBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) bal return b } -func (b priorityBB) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (b bb) ParseConfig(s json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(s) } -func (priorityBB) Name() string { +func (bb) Name() string { return Name } diff --git a/xds/internal/balancer/priority/balancer_priority.go b/xds/internal/balancer/priority/balancer_priority.go index 672f17498b16..3ba0b9e929d7 100644 --- a/xds/internal/balancer/priority/balancer_priority.go +++ b/xds/internal/balancer/priority/balancer_priority.go @@ 
-28,7 +28,8 @@ import ( ) var ( - errAllPrioritiesRemoved = errors.New("no locality is provided, all priorities are removed") + // ErrAllPrioritiesRemoved is returned by the picker when there's no priority available. + ErrAllPrioritiesRemoved = errors.New("no priority is provided, all priorities are removed") // DefaultPriorityInitTimeout is the timeout after which if a priority is // not READY, the next will be started. It's exported to be overridden by // tests. @@ -73,7 +74,7 @@ func (b *priorityBalancer) syncPriority() { b.stopPriorityInitTimer() b.cc.UpdateState(balancer.State{ ConnectivityState: connectivity.TransientFailure, - Picker: base.NewErrPicker(errAllPrioritiesRemoved), + Picker: base.NewErrPicker(ErrAllPrioritiesRemoved), }) return } diff --git a/xds/internal/balancer/priority/balancer_test.go b/xds/internal/balancer/priority/balancer_test.go index 03442671311b..bc43a74557e6 100644 --- a/xds/internal/balancer/priority/balancer_test.go +++ b/xds/internal/balancer/priority/balancer_test.go @@ -828,8 +828,8 @@ func (s) TestPriority_RemovesAllPriorities(t *testing.T) { // Test pick return TransientFailure. pFail := <-cc.NewPickerCh for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved { - t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err) + if _, err := pFail.Pick(balancer.PickInfo{}); err != ErrAllPrioritiesRemoved { + t.Fatalf("want pick error %v, got %v", ErrAllPrioritiesRemoved, err) } } @@ -1436,8 +1436,8 @@ func (s) TestPriority_ReadyChildRemovedButInCache(t *testing.T) { pFail := <-cc.NewPickerCh for i := 0; i < 5; i++ { - if _, err := pFail.Pick(balancer.PickInfo{}); err != errAllPrioritiesRemoved { - t.Fatalf("want pick error %v, got %v", errAllPrioritiesRemoved, err) + if _, err := pFail.Pick(balancer.PickInfo{}); err != ErrAllPrioritiesRemoved { + t.Fatalf("want pick error %v, got %v", ErrAllPrioritiesRemoved, err) } } diff --git a/xds/internal/balancer/priority/config.go b/xds/internal/balancer/priority/config.go index c9cb16e323f0..37f1c9a829a8 100644 --- a/xds/internal/balancer/priority/config.go +++ b/xds/internal/balancer/priority/config.go @@ -29,7 +29,7 @@ import ( // Child is a child of priority balancer. type Child struct { Config *internalserviceconfig.BalancerConfig `json:"config,omitempty"` - IgnoreReresolutionRequests bool + IgnoreReresolutionRequests bool `json:"ignoreReresolutionRequests,omitempty"` } // LBConfig represents priority balancer's config. diff --git a/xds/internal/balancer/ringhash/util.go b/xds/internal/balancer/ringhash/util.go new file mode 100644 index 000000000000..848b20844d89 --- /dev/null +++ b/xds/internal/balancer/ringhash/util.go @@ -0,0 +1,41 @@ +/* + * + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package ringhash contains the functionality to support Ring Hash in grpc. 
+package ringhash + +import "context" + +type clusterKey struct{} + +func getRequestHash(ctx context.Context) uint64 { + requestHash, _ := ctx.Value(clusterKey{}).(uint64) + return requestHash +} + +// GetRequestHashForTesting returns the request hash in the context; to be used +// for testing only. +func GetRequestHashForTesting(ctx context.Context) uint64 { + return getRequestHash(ctx) +} + +// SetRequestHash adds the request hash to the context for use in Ring Hash Load +// Balancing. +func SetRequestHash(ctx context.Context, requestHash uint64) context.Context { + return context.WithValue(ctx, clusterKey{}, requestHash) +} diff --git a/xds/internal/balancer/weightedtarget/weightedtarget.go b/xds/internal/balancer/weightedtarget/weightedtarget.go index 6c1b70f92235..eb6516af56e9 100644 --- a/xds/internal/balancer/weightedtarget/weightedtarget.go +++ b/xds/internal/balancer/weightedtarget/weightedtarget.go @@ -42,12 +42,12 @@ const Name = "weighted_target_experimental" var NewRandomWRR = wrr.NewRandom func init() { - balancer.Register(&weightedTargetBB{}) + balancer.Register(bb{}) } -type weightedTargetBB struct{} +type bb struct{} -func (wt *weightedTargetBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { +func (bb) Build(cc balancer.ClientConn, bOpts balancer.BuildOptions) balancer.Balancer { b := &weightedTargetBalancer{} b.logger = prefixLogger(b) b.stateAggregator = weightedaggregator.New(cc, b.logger, NewRandomWRR) @@ -58,11 +58,11 @@ func (wt *weightedTargetBB) Build(cc balancer.ClientConn, bOpts balancer.BuildOp return b } -func (wt *weightedTargetBB) Name() string { +func (bb) Name() string { return Name } -func (wt *weightedTargetBB) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { +func (bb) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { return parseConfig(c) } @@ -81,10 +81,10 @@ type weightedTargetBalancer struct { } // UpdateClientConnState takes the new targets in balancer group, -// creates/deletes sub-balancers and sends them update. Addresses are split into +// creates/deletes sub-balancers and sends them update. addresses are split into // groups based on hierarchy path. -func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnState) error { - w.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) +func (b *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnState) error { + b.logger.Infof("Received update from resolver, balancer config: %+v", pretty.ToJSON(s.BalancerConfig)) newConfig, ok := s.BalancerConfig.(*LBConfig) if !ok { return fmt.Errorf("unexpected balancer config with type: %T", s.BalancerConfig) @@ -94,10 +94,10 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat var rebuildStateAndPicker bool // Remove sub-pickers and sub-balancers that are not in the new config. - for name := range w.targets { + for name := range b.targets { if _, ok := newConfig.Targets[name]; !ok { - w.stateAggregator.Remove(name) - w.bg.Remove(name) + b.stateAggregator.Remove(name) + b.bg.Remove(name) // Trigger a state/picker update, because we don't want `ClientConn` // to pick this sub-balancer anymore. rebuildStateAndPicker = true @@ -110,39 +110,39 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat // // For all sub-balancers, forward the address/balancer config update. 
 	for name, newT := range newConfig.Targets {
-		oldT, ok := w.targets[name]
+		oldT, ok := b.targets[name]
 		if !ok {
 			// If this is a new sub-balancer, add weights to the picker map.
-			w.stateAggregator.Add(name, newT.Weight)
+			b.stateAggregator.Add(name, newT.Weight)
 			// Then add to the balancer group.
-			w.bg.Add(name, balancer.Get(newT.ChildPolicy.Name))
+			b.bg.Add(name, balancer.Get(newT.ChildPolicy.Name))
 			// Don't trigger a state/picker update. Wait for the new sub-balancer
 			// to send its updates.
 		} else if newT.ChildPolicy.Name != oldT.ChildPolicy.Name {
 			// If the child policy name is different, remove from balancer group
 			// and re-add.
-			w.stateAggregator.Remove(name)
-			w.bg.Remove(name)
-			w.stateAggregator.Add(name, newT.Weight)
-			w.bg.Add(name, balancer.Get(newT.ChildPolicy.Name))
+			b.stateAggregator.Remove(name)
+			b.bg.Remove(name)
+			b.stateAggregator.Add(name, newT.Weight)
+			b.bg.Add(name, balancer.Get(newT.ChildPolicy.Name))
 			// Trigger a state/picker update, because we don't want `ClientConn`
 			// to pick this sub-balancer anymore.
 			rebuildStateAndPicker = true
 		} else if newT.Weight != oldT.Weight {
 			// If this is an existing sub-balancer, update weight if necessary.
-			w.stateAggregator.UpdateWeight(name, newT.Weight)
+			b.stateAggregator.UpdateWeight(name, newT.Weight)
 			// Trigger a state/picker update, because we want `ClientConn`
 			// to do picks with the new weights now.
 			rebuildStateAndPicker = true
 		}

 		// Forwards all the updates:
-		// - Addresses are from the map after splitting with hierarchy path,
+		// - addresses are from the map after splitting with hierarchy path,
 		// - Top level service config and attributes are the same,
 		// - Balancer config comes from the targets map.
 		//
 		// TODO: handle error? How to aggregate errors and return?
-		_ = w.bg.UpdateClientConnState(name, balancer.ClientConnState{
+		_ = b.bg.UpdateClientConnState(name, balancer.ClientConnState{
 			ResolverState: resolver.State{
 				Addresses:     addressesSplit[name],
 				ServiceConfig: s.ResolverState.ServiceConfig,
@@ -152,23 +152,23 @@ func (w *weightedTargetBalancer) UpdateClientConnState(s balancer.ClientConnStat
 		})
 	}

-	w.targets = newConfig.Targets
+	b.targets = newConfig.Targets

 	if rebuildStateAndPicker {
-		w.stateAggregator.BuildAndUpdate()
+		b.stateAggregator.BuildAndUpdate()
 	}
 	return nil
 }

-func (w *weightedTargetBalancer) ResolverError(err error) {
-	w.bg.ResolverError(err)
+func (b *weightedTargetBalancer) ResolverError(err error) {
+	b.bg.ResolverError(err)
 }

-func (w *weightedTargetBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
-	w.bg.UpdateSubConnState(sc, state)
+func (b *weightedTargetBalancer) UpdateSubConnState(sc balancer.SubConn, state balancer.SubConnState) {
+	b.bg.UpdateSubConnState(sc, state)
 }

-func (w *weightedTargetBalancer) Close() {
-	w.stateAggregator.Stop()
-	w.bg.Close()
+func (b *weightedTargetBalancer) Close() {
+	b.stateAggregator.Stop()
+	b.bg.Close()
 }
diff --git a/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go b/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go
index 351a13553e4e..3cd6f74df723 100644
--- a/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go
+++ b/xds/internal/balancer/weightedtarget/weightedtarget_config_test.go
@@ -26,8 +26,7 @@ import (
 	"github.com/google/go-cmp/cmp"
 	"google.golang.org/grpc/balancer"
 	internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
-
-	_ "google.golang.org/grpc/xds/internal/balancer/lrs" // Register LRS balancer, so we can use it as child policy in
the config tests. + "google.golang.org/grpc/xds/internal/balancer/priority" ) const ( @@ -35,23 +34,22 @@ const ( "targets": { "cluster_1" : { "weight":75, - "childPolicy":[{"lrs_experimental":{"clusterName":"cluster_1","lrsLoadReportingServerName":"lrs.server","locality":{"zone":"test-zone-1"}}}] + "childPolicy":[{"priority_experimental":{"priorities": ["child-1"], "children": {"child-1": {"config": [{"round_robin":{}}]}}}}] }, "cluster_2" : { "weight":25, - "childPolicy":[{"lrs_experimental":{"clusterName":"cluster_2","lrsLoadReportingServerName":"lrs.server","locality":{"zone":"test-zone-2"}}}] + "childPolicy":[{"priority_experimental":{"priorities": ["child-2"], "children": {"child-2": {"config": [{"round_robin":{}}]}}}}] } } }` - lrsBalancerName = "lrs_experimental" ) var ( - lrsConfigParser = balancer.Get(lrsBalancerName).(balancer.ConfigParser) - lrsConfigJSON1 = `{"clusterName":"cluster_1","lrsLoadReportingServerName":"lrs.server","locality":{"zone":"test-zone-1"}}` - lrsConfig1, _ = lrsConfigParser.ParseConfig([]byte(lrsConfigJSON1)) - lrsConfigJSON2 = `{"clusterName":"cluster_2","lrsLoadReportingServerName":"lrs.server","locality":{"zone":"test-zone-2"}}` - lrsConfig2, _ = lrsConfigParser.ParseConfig([]byte(lrsConfigJSON2)) + testConfigParser = balancer.Get(priority.Name).(balancer.ConfigParser) + testConfigJSON1 = `{"priorities": ["child-1"], "children": {"child-1": {"config": [{"round_robin":{}}]}}}` + testConfig1, _ = testConfigParser.ParseConfig([]byte(testConfigJSON1)) + testConfigJSON2 = `{"priorities": ["child-2"], "children": {"child-2": {"config": [{"round_robin":{}}]}}}` + testConfig2, _ = testConfigParser.ParseConfig([]byte(testConfigJSON2)) ) func Test_parseConfig(t *testing.T) { @@ -75,15 +73,15 @@ func Test_parseConfig(t *testing.T) { "cluster_1": { Weight: 75, ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: lrsBalancerName, - Config: lrsConfig1, + Name: priority.Name, + Config: testConfig1, }, }, "cluster_2": { Weight: 25, ChildPolicy: &internalserviceconfig.BalancerConfig{ - Name: lrsBalancerName, - Config: lrsConfig2, + Name: priority.Name, + Config: testConfig2, }, }, }, diff --git a/xds/internal/client/tests/README.md b/xds/internal/client/tests/README.md deleted file mode 100644 index 6dc940c103f7..000000000000 --- a/xds/internal/client/tests/README.md +++ /dev/null @@ -1 +0,0 @@ -This package contains tests which cannot live in the `client` package because they need to import one of the API client packages (which itself has a dependency on the `client` package). 
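Editor's note: the `weightedtarget_config_test.go` changes above swap the deleted `lrs_experimental` child policy for `priority_experimental` in the test JSON. The following is a minimal standalone sketch (not part of this diff) of how such a child-policy config is resolved through the balancer registry; it assumes it is run from inside the grpc-go module, since the `xds/internal/...` packages are not importable from outside, and the error handling is illustrative only:

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"

	"google.golang.org/grpc/balancer"
	_ "google.golang.org/grpc/xds/internal/balancer/priority" // registers "priority_experimental"
)

func main() {
	// The same shape of config the test above feeds to parseConfig.
	cfg := json.RawMessage(`{"priorities": ["child-1"], "children": {"child-1": {"config": [{"round_robin":{}}]}}}`)

	// Look the builder up by name; builders that accept JSON configs also
	// implement balancer.ConfigParser.
	b := balancer.Get("priority_experimental")
	if b == nil {
		log.Fatal("priority_experimental not registered")
	}
	parser, ok := b.(balancer.ConfigParser)
	if !ok {
		log.Fatal("builder does not implement balancer.ConfigParser")
	}
	parsed, err := parser.ParseConfig(cfg)
	if err != nil {
		log.Fatalf("ParseConfig failed: %v", err)
	}
	fmt.Printf("parsed config: %+v\n", parsed)
}
```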
diff --git a/xds/internal/httpfilter/fault/fault.go b/xds/internal/httpfilter/fault/fault.go index ee2ed9fd4922..42f8e70af93b 100644 --- a/xds/internal/httpfilter/fault/fault.go +++ b/xds/internal/httpfilter/fault/fault.go @@ -33,7 +33,6 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpcrand" iresolver "google.golang.org/grpc/internal/resolver" - "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal/httpfilter" @@ -63,9 +62,7 @@ var statusMap = map[int]codes.Code{ } func init() { - if env.FaultInjectionSupport { - httpfilter.Register(builder{}) - } + httpfilter.Register(builder{}) } type builder struct { diff --git a/xds/internal/httpfilter/fault/fault_test.go b/xds/internal/httpfilter/fault/fault_test.go index c132e912f92a..a77ab58ad356 100644 --- a/xds/internal/httpfilter/fault/fault_test.go +++ b/xds/internal/httpfilter/fault/fault_test.go @@ -40,10 +40,8 @@ import ( "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/xds" - "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" - "google.golang.org/grpc/xds/internal/httpfilter" xtestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/e2e" "google.golang.org/protobuf/types/known/wrapperspb" @@ -55,9 +53,9 @@ import ( tpb "github.com/envoyproxy/go-control-plane/envoy/type/v3" testpb "google.golang.org/grpc/test/grpc_testing" - _ "google.golang.org/grpc/xds/internal/balancer" // Register the balancers. - _ "google.golang.org/grpc/xds/internal/client/v3" // Register the v3 xDS API client. - _ "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + _ "google.golang.org/grpc/xds/internal/balancer" // Register the balancers. + _ "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register the v3 xDS API client. ) type s struct { @@ -140,13 +138,6 @@ func clientSetup(t *testing.T) (*e2e.ManagementServer, string, uint32, func()) { } } -func init() { - env.FaultInjectionSupport = true - // Manually register to avoid a race between this init and the one that - // check the env var to register the fault injection filter. 
- httpfilter.Register(builder{}) -} - func (s) TestFaultInjection_Unary(t *testing.T) { type subcase struct { name string @@ -617,7 +608,7 @@ func (s) TestFaultInjection_MaxActiveFaults(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - streams := make(chan testpb.TestService_FullDuplexCallClient) + streams := make(chan testpb.TestService_FullDuplexCallClient, 5) // startStream() is called 5 times startStream := func() { str, err := client.FullDuplexCall(ctx) if err != nil { @@ -629,7 +620,7 @@ func (s) TestFaultInjection_MaxActiveFaults(t *testing.T) { str := <-streams str.CloseSend() if _, err := str.Recv(); err != io.EOF { - t.Fatal("stream error:", err) + t.Error("stream error:", err) } } releaseStream := func() { diff --git a/xds/internal/internal.go b/xds/internal/internal.go index e4284ee02e0c..0cccd3824101 100644 --- a/xds/internal/internal.go +++ b/xds/internal/internal.go @@ -22,6 +22,8 @@ package internal import ( "encoding/json" "fmt" + + "google.golang.org/grpc/resolver" ) // LocalityID is xds.Locality without XXX fields, so it can be used as map @@ -53,3 +55,19 @@ func LocalityIDFromString(s string) (ret LocalityID, _ error) { } return ret, nil } + +type localityKeyType string + +const localityKey = localityKeyType("grpc.xds.internal.address.locality") + +// GetLocalityID returns the locality ID of addr. +func GetLocalityID(addr resolver.Address) LocalityID { + path, _ := addr.Attributes.Value(localityKey).(LocalityID) + return path +} + +// SetLocalityID sets locality ID in addr to l. +func SetLocalityID(addr resolver.Address, l LocalityID) resolver.Address { + addr.Attributes = addr.Attributes.WithValues(localityKey, l) + return addr +} diff --git a/xds/internal/resolver/matcher.go b/xds/internal/resolver/matcher.go index e329944e1db1..6e09d93afa78 100644 --- a/xds/internal/resolver/matcher.go +++ b/xds/internal/resolver/matcher.go @@ -27,25 +27,25 @@ import ( iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/xds/matcher" "google.golang.org/grpc/metadata" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/xdsclient" ) func routeToMatcher(r *xdsclient.Route) (*compositeMatcher, error) { - var pathMatcher pathMatcherInterface + var pm pathMatcher switch { case r.Regex != nil: - pathMatcher = newPathRegexMatcher(r.Regex) + pm = newPathRegexMatcher(r.Regex) case r.Path != nil: - pathMatcher = newPathExactMatcher(*r.Path, r.CaseInsensitive) + pm = newPathExactMatcher(*r.Path, r.CaseInsensitive) case r.Prefix != nil: - pathMatcher = newPathPrefixMatcher(*r.Prefix, r.CaseInsensitive) + pm = newPathPrefixMatcher(*r.Prefix, r.CaseInsensitive) default: return nil, fmt.Errorf("illegal route: missing path_matcher") } - var headerMatchers []matcher.HeaderMatcherInterface + var headerMatchers []matcher.HeaderMatcher for _, h := range r.Headers { - var matcherT matcher.HeaderMatcherInterface + var matcherT matcher.HeaderMatcher switch { case h.ExactMatch != nil && *h.ExactMatch != "": matcherT = matcher.NewHeaderExactMatcher(h.Name, *h.ExactMatch) @@ -72,17 +72,17 @@ func routeToMatcher(r *xdsclient.Route) (*compositeMatcher, error) { if r.Fraction != nil { fractionMatcher = newFractionMatcher(*r.Fraction) } - return newCompositeMatcher(pathMatcher, headerMatchers, fractionMatcher), nil + return newCompositeMatcher(pm, headerMatchers, fractionMatcher), nil } // compositeMatcher.match returns true if all matchers return true. 
type compositeMatcher struct { - pm pathMatcherInterface - hms []matcher.HeaderMatcherInterface + pm pathMatcher + hms []matcher.HeaderMatcher fm *fractionMatcher } -func newCompositeMatcher(pm pathMatcherInterface, hms []matcher.HeaderMatcherInterface, fm *fractionMatcher) *compositeMatcher { +func newCompositeMatcher(pm pathMatcher, hms []matcher.HeaderMatcher, fm *fractionMatcher) *compositeMatcher { return &compositeMatcher{pm: pm, hms: hms, fm: fm} } diff --git a/xds/internal/resolver/matcher_path.go b/xds/internal/resolver/matcher_path.go index 011d1a94c49c..88a04f6d7bef 100644 --- a/xds/internal/resolver/matcher_path.go +++ b/xds/internal/resolver/matcher_path.go @@ -23,7 +23,7 @@ import ( "strings" ) -type pathMatcherInterface interface { +type pathMatcher interface { match(path string) bool String() string } diff --git a/xds/internal/resolver/matcher_test.go b/xds/internal/resolver/matcher_test.go index 6f599b82da2a..f7d5486cc136 100644 --- a/xds/internal/resolver/matcher_test.go +++ b/xds/internal/resolver/matcher_test.go @@ -34,8 +34,8 @@ import ( func TestAndMatcherMatch(t *testing.T) { tests := []struct { name string - pm pathMatcherInterface - hm matcher.HeaderMatcherInterface + pm pathMatcher + hm matcher.HeaderMatcher info iresolver.RPCInfo want bool }{ @@ -108,7 +108,7 @@ func TestAndMatcherMatch(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - a := newCompositeMatcher(tt.pm, []matcher.HeaderMatcherInterface{tt.hm}, nil) + a := newCompositeMatcher(tt.pm, []matcher.HeaderMatcher{tt.hm}, nil) if got := a.match(tt.info); got != tt.want { t.Errorf("match() = %v, want %v", got, tt.want) } diff --git a/xds/internal/resolver/serviceconfig.go b/xds/internal/resolver/serviceconfig.go index 49ee8970d7d8..9eaff52dbcce 100644 --- a/xds/internal/resolver/serviceconfig.go +++ b/xds/internal/resolver/serviceconfig.go @@ -22,18 +22,24 @@ import ( "context" "encoding/json" "fmt" + "math/bits" + "strings" "sync/atomic" "time" + "github.com/cespare/xxhash" "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/grpcrand" iresolver "google.golang.org/grpc/internal/resolver" "google.golang.org/grpc/internal/wrr" "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal/balancer/clustermanager" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/balancer/ringhash" "google.golang.org/grpc/xds/internal/httpfilter" "google.golang.org/grpc/xds/internal/httpfilter/router" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -116,6 +122,7 @@ type route struct { maxStreamDuration time.Duration // map from filter name to its config httpFilterConfigOverride map[string]httpfilter.FilterConfig + hashPolicies []*xdsclient.HashPolicy } func (r route) String() string { @@ -161,9 +168,15 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP return nil, err } + lbCtx := clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name) + // Request Hashes are only applicable for a Ring Hash LB. + if env.RingHashSupport { + lbCtx = ringhash.SetRequestHash(lbCtx, cs.generateHash(rpcInfo, rt.hashPolicies)) + } + config := &iresolver.RPCConfig{ - // Communicate to the LB policy the chosen cluster. - Context: clustermanager.SetPickedCluster(rpcInfo.Context, cluster.name), + // Communicate to the LB policy the chosen cluster and request hash, if Ring Hash LB policy. 
+		Context: lbCtx,
 		OnCommitted: func() {
 			// When the RPC is committed, the cluster is no longer required.
 			// Decrease its ref.
@@ -179,13 +192,68 @@ func (cs *configSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*iresolver.RP
 		Interceptor: interceptor,
 	}

-	if env.TimeoutSupport && rt.maxStreamDuration != 0 {
+	if rt.maxStreamDuration != 0 {
 		config.MethodConfig.Timeout = &rt.maxStreamDuration
 	}

 	return config, nil
 }

+func (cs *configSelector) generateHash(rpcInfo iresolver.RPCInfo, hashPolicies []*xdsclient.HashPolicy) uint64 {
+	var hash uint64
+	var generatedHash bool
+	for _, policy := range hashPolicies {
+		var policyHash uint64
+		var generatedPolicyHash bool
+		switch policy.HashPolicyType {
+		case xdsclient.HashPolicyTypeHeader:
+			md, ok := metadata.FromIncomingContext(rpcInfo.Context)
+			if !ok {
+				continue
+			}
+			values := md.Get(policy.HeaderName)
+			// If the header isn't present, no-op.
+			if len(values) == 0 {
+				continue
+			}
+			joinedValues := strings.Join(values, ",")
+			if policy.Regex != nil {
+				joinedValues = policy.Regex.ReplaceAllString(joinedValues, policy.RegexSubstitution)
+			}
+			policyHash = xxhash.Sum64String(joinedValues)
+			generatedHash = true
+			generatedPolicyHash = true
+		case xdsclient.HashPolicyTypeChannelID:
+			// Hash the ClientConn pointer, which logically uniquely
+			// identifies the client.
+			policyHash = xxhash.Sum64String(fmt.Sprintf("%p", &cs.r.cc))
+			generatedHash = true
+			generatedPolicyHash = true
+		}
+
+		// Deterministically combine the hash policies. Rotating prevents
+		// duplicate hash policies from cancelling each other out and preserves
+		// the 64 bits of entropy.
+		if generatedPolicyHash {
+			hash = bits.RotateLeft64(hash, 1)
+			hash = hash ^ policyHash
+		}
+
+		// If the policy is terminal and a hash has already been generated,
+		// ignore the rest of the policies and use the hash generated so far.
+		if policy.Terminal && generatedHash {
+			break
+		}
+	}
+
+	if generatedHash {
+		return hash
+	}
+	// If no hash was generated, return a random uint64. In the grand scheme of
+	// things this logically maps to choosing a random backend to route the
+	// request to.
+	return grpcrand.Uint64()
+}
+
 func (cs *configSelector) newInterceptor(rt *route, cluster *routeCluster) (iresolver.ClientInterceptor, error) {
 	if len(cs.httpFilterConfig) == 0 {
 		return nil, nil
@@ -293,6 +361,7 @@ func (r *xdsResolver) newConfigSelector(su serviceUpdate) (*configSelector, erro
 		}
 		cs.routes[i].httpFilterConfigOverride = rt.HTTPFilterConfigOverride
+		cs.routes[i].hashPolicies = rt.HashPolicies
 	}

 	// Account for this config selector's clusters.
Do this after no further diff --git a/xds/internal/resolver/serviceconfig_test.go b/xds/internal/resolver/serviceconfig_test.go index 7fe8218160f5..b0bc86c882b3 100644 --- a/xds/internal/resolver/serviceconfig_test.go +++ b/xds/internal/resolver/serviceconfig_test.go @@ -21,10 +21,17 @@ package resolver import ( + "context" + "fmt" + "regexp" "testing" + "github.com/cespare/xxhash" "github.com/google/go-cmp/cmp" + iresolver "google.golang.org/grpc/internal/resolver" + "google.golang.org/grpc/metadata" _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config + "google.golang.org/grpc/xds/internal/xdsclient" ) func (s) TestPruneActiveClusters(t *testing.T) { @@ -43,3 +50,70 @@ func (s) TestPruneActiveClusters(t *testing.T) { t.Fatalf("r.activeClusters = %v; want %v\nDiffs: %v", r.activeClusters, want, d) } } + +func (s) TestGenerateRequestHash(t *testing.T) { + cs := &configSelector{ + r: &xdsResolver{ + cc: &testClientConn{}, + }, + } + tests := []struct { + name string + hashPolicies []*xdsclient.HashPolicy + requestHashWant uint64 + rpcInfo iresolver.RPCInfo + }{ + // TestGenerateRequestHashHeaders tests generating request hashes for + // hash policies that specify to hash headers. + { + name: "test-generate-request-hash-headers", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), // Will replace /products with /new-products, to test find and replace functionality. + RegexSubstitution: "/new-products", + }}, + requestHashWant: xxhash.Sum64String("/new-products"), + rpcInfo: iresolver.RPCInfo{ + Context: metadata.NewIncomingContext(context.Background(), metadata.Pairs(":path", "/products")), + Method: "/some-method", + }, + }, + // TestGenerateHashChannelID tests generating request hashes for hash + // policies that specify to hash something that uniquely identifies the + // ClientConn (the pointer). + { + name: "test-generate-request-hash-channel-id", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeChannelID, + }}, + requestHashWant: xxhash.Sum64String(fmt.Sprintf("%p", &cs.r.cc)), + rpcInfo: iresolver.RPCInfo{}, + }, + // TestGenerateRequestHashEmptyString tests generating request hashes + // for hash policies that specify to hash headers and replace empty + // strings in the headers. 
+ { + name: "test-generate-request-hash-empty-string", + hashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("") }(), + RegexSubstitution: "e", + }}, + requestHashWant: xxhash.Sum64String("eaebece"), + rpcInfo: iresolver.RPCInfo{ + Context: metadata.NewIncomingContext(context.Background(), metadata.Pairs(":path", "abc")), + Method: "/some-method", + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + requestHashGot := cs.generateHash(test.rpcInfo, test.hashPolicies) + if requestHashGot != test.requestHashWant { + t.Fatalf("requestHashGot = %v, requestHashWant = %v", requestHashGot, test.requestHashWant) + } + }) + } +} diff --git a/xds/internal/resolver/watch_service.go b/xds/internal/resolver/watch_service.go index f29fb32832e0..bea5bbcda140 100644 --- a/xds/internal/resolver/watch_service.go +++ b/xds/internal/resolver/watch_service.go @@ -26,7 +26,7 @@ import ( "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/pretty" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/xdsclient" ) // serviceUpdate contains information received from the LDS/RDS responses which @@ -54,7 +54,7 @@ type ldsConfig struct { // Note that during race (e.g. an xDS response is received while the user is // calling cancel()), there's a small window where the callback can be called // after the watcher is canceled. The caller needs to handle this case. -func watchService(c xdsClientInterface, serviceName string, cb func(serviceUpdate, error), logger *grpclog.PrefixLogger) (cancel func()) { +func watchService(c xdsclient.XDSClient, serviceName string, cb func(serviceUpdate, error), logger *grpclog.PrefixLogger) (cancel func()) { w := &serviceUpdateWatcher{ logger: logger, c: c, @@ -70,7 +70,7 @@ func watchService(c xdsClientInterface, serviceName string, cb func(serviceUpdat // callback at the right time. 
 type serviceUpdateWatcher struct {
 	logger      *grpclog.PrefixLogger
-	c           xdsClientInterface
+	c           xdsclient.XDSClient
 	serviceName string
 	ldsCancel   func()
 	serviceCb   func(serviceUpdate, error)
diff --git a/xds/internal/resolver/watch_service_test.go b/xds/internal/resolver/watch_service_test.go
index 421e5345a9d2..31c45bf3977f 100644
--- a/xds/internal/resolver/watch_service_test.go
+++ b/xds/internal/resolver/watch_service_test.go
@@ -29,8 +29,8 @@ import (
 	"github.com/google/go-cmp/cmp"
 	"github.com/google/go-cmp/cmp/cmpopts"
 	"google.golang.org/grpc/internal/testutils"
-	xdsclient "google.golang.org/grpc/xds/internal/client"
 	"google.golang.org/grpc/xds/internal/testutils/fakeclient"
+	"google.golang.org/grpc/xds/internal/xdsclient"
 	"google.golang.org/protobuf/proto"
 )
diff --git a/xds/internal/resolver/xds_resolver.go b/xds/internal/resolver/xds_resolver.go
index e995fa9fa8fe..19ee01773e8c 100644
--- a/xds/internal/resolver/xds_resolver.go
+++ b/xds/internal/resolver/xds_resolver.go
@@ -27,23 +27,34 @@ import (
 	"google.golang.org/grpc/internal/grpclog"
 	"google.golang.org/grpc/internal/grpcsync"
 	"google.golang.org/grpc/internal/pretty"
-	"google.golang.org/grpc/resolver"
-	"google.golang.org/grpc/xds/internal/client/bootstrap"
-
 	iresolver "google.golang.org/grpc/internal/resolver"
-	xdsclient "google.golang.org/grpc/xds/internal/client"
+	"google.golang.org/grpc/resolver"
+	"google.golang.org/grpc/xds/internal/xdsclient"
 )

 const xdsScheme = "xds"

+// NewBuilder creates a new xds resolver builder using a specific xds bootstrap
+// config, so tests can use multiple xds clients in different ClientConns at
+// the same time.
+func NewBuilder(config []byte) (resolver.Builder, error) {
+	return &xdsResolverBuilder{
+		newXDSClient: func() (xdsclient.XDSClient, error) {
+			return xdsclient.NewClientWithBootstrapContents(config)
+		},
+	}, nil
+}
+
 // For overriding in unittests.
-var newXDSClient = func() (xdsClientInterface, error) { return xdsclient.New() }
+var newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.New() }

 func init() {
 	resolver.Register(&xdsResolverBuilder{})
 }

-type xdsResolverBuilder struct{}
+type xdsResolverBuilder struct {
+	newXDSClient func() (xdsclient.XDSClient, error)
+}

 // Build helps implement the resolver.Builder interface.
 //
@@ -60,6 +71,11 @@ func (b *xdsResolverBuilder) Build(t resolver.Target, cc resolver.ClientConn, op
 	r.logger = prefixLogger((r))
 	r.logger.Infof("Creating resolver for target: %+v", t)

+	newXDSClient := newXDSClient
+	if b.newXDSClient != nil {
+		newXDSClient = b.newXDSClient
+	}
+
 	client, err := newXDSClient()
 	if err != nil {
 		return nil, fmt.Errorf("xds: failed to create xds-client: %v", err)
@@ -101,15 +117,6 @@ func (*xdsResolverBuilder) Scheme() string {
 	return xdsScheme
 }

-// xdsClientInterface contains methods from xdsClient.Client which are used by
-// the resolver. This will be faked out in unittests.
-type xdsClientInterface interface {
-	WatchListener(serviceName string, cb func(xdsclient.ListenerUpdate, error)) func()
-	WatchRouteConfig(routeName string, cb func(xdsclient.RouteConfigUpdate, error)) func()
-	BootstrapConfig() *bootstrap.Config
-	Close()
-}
-
 // suWithError wraps the ServiceUpdate and error received through a watch API
 // callback, so that it can be pushed onto the update channel as a single entity.
 type suWithError struct {
@@ -131,7 +138,7 @@ type xdsResolver struct {
 	logger *grpclog.PrefixLogger
 	// The underlying xdsClient which performs all xDS requests and responses.
- client xdsClientInterface + client xdsclient.XDSClient // A channel for the watch API callback to write service updates on to. The // updates are read by the run goroutine and passed on to the ClientConn. updateCh chan suWithError @@ -178,7 +185,7 @@ func (r *xdsResolver) sendNewServiceConfig(cs *configSelector) bool { state := iresolver.SetConfigSelector(resolver.State{ ServiceConfig: r.cc.ParseServiceConfig(string(sc)), }, cs) - r.cc.UpdateState(state) + r.cc.UpdateState(xdsclient.SetClient(state, r.client)) return true } diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go index eca561f369c5..9bce8ffe8bf6 100644 --- a/xds/internal/resolver/xds_resolver_test.go +++ b/xds/internal/resolver/xds_resolver_test.go @@ -28,6 +28,7 @@ import ( "testing" "time" + "github.com/cespare/xxhash" "github.com/google/go-cmp/cmp" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" @@ -39,18 +40,19 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/internal/wrr" "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" "google.golang.org/grpc/status" _ "google.golang.org/grpc/xds/internal/balancer/cdsbalancer" // To parse LB config "google.golang.org/grpc/xds/internal/balancer/clustermanager" - "google.golang.org/grpc/xds/internal/client" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/balancer/ringhash" "google.golang.org/grpc/xds/internal/httpfilter" "google.golang.org/grpc/xds/internal/httpfilter/router" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( @@ -115,19 +117,19 @@ func newTestClientConn() *testClientConn { func (s) TestResolverBuilder(t *testing.T) { tests := []struct { name string - xdsClientFunc func() (xdsClientInterface, error) + xdsClientFunc func() (xdsclient.XDSClient, error) wantErr bool }{ { name: "simple-good", - xdsClientFunc: func() (xdsClientInterface, error) { + xdsClientFunc: func() (xdsclient.XDSClient, error) { return fakeclient.NewClient(), nil }, wantErr: false, }, { name: "newXDSClient-throws-error", - xdsClientFunc: func() (xdsClientInterface, error) { + xdsClientFunc: func() (xdsclient.XDSClient, error) { return nil, errors.New("newXDSClient-throws-error") }, wantErr: true, @@ -168,7 +170,7 @@ func (s) TestResolverBuilder_xdsCredsBootstrapMismatch(t *testing.T) { // Fake out the xdsClient creation process by providing a fake, which does // not have any certificate provider configuration. 
oldClientMaker := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { fc := fakeclient.NewClient() fc.SetBootstrapConfig(&bootstrap.Config{}) return fc, nil @@ -195,7 +197,7 @@ func (s) TestResolverBuilder_xdsCredsBootstrapMismatch(t *testing.T) { } type setupOpts struct { - xdsClientFunc func() (xdsClientInterface, error) + xdsClientFunc func() (xdsclient.XDSClient, error) } func testSetup(t *testing.T, opts setupOpts) (*xdsResolver, *testClientConn, func()) { @@ -255,7 +257,7 @@ func waitForWatchRouteConfig(ctx context.Context, t *testing.T, xdsC *fakeclient func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() @@ -272,7 +274,7 @@ func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, }, }, }, nil) @@ -282,12 +284,26 @@ func (s) TestXDSResolverWatchCallbackAfterClose(t *testing.T) { } } +// TestXDSResolverCloseClosesXDSClient tests that the XDS resolver's Close +// method closes the XDS client. +func (s) TestXDSResolverCloseClosesXDSClient(t *testing.T) { + xdsC := fakeclient.NewClient() + xdsR, _, cancel := testSetup(t, setupOpts{ + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, + }) + defer cancel() + xdsR.Close() + if !xdsC.Closed.HasFired() { + t.Fatalf("xds client not closed by xds resolver Close method") + } +} + // TestXDSResolverBadServiceUpdate tests the case the xdsClient returns a bad // service update. 
func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -313,7 +329,7 @@ func (s) TestXDSResolverBadServiceUpdate(t *testing.T) { func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -331,7 +347,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters map[string]bool }{ { - routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, wantJSON: `{"loadBalancingConfig":[{ "xds_cluster_manager_experimental":{ "children":{ @@ -343,7 +359,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters: map[string]bool{"test-cluster-1": true}, }, { - routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "cluster_1": {Weight: 75}, "cluster_2": {Weight: 25}, }}}, @@ -367,7 +383,7 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { wantClusters: map[string]bool{"cluster_1": true, "cluster_2": true}, }, { - routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "cluster_1": {Weight: 75}, "cluster_2": {Weight: 25}, }}}, @@ -442,12 +458,77 @@ func (s) TestXDSResolverGoodServiceUpdate(t *testing.T) { } } +// TestXDSResolverRequestHash tests a case where a resolver receives a RouteConfig update +// with a HashPolicy specifying to generate a hash. The configSelector generated should +// successfully generate a Hash. +func (s) TestXDSResolverRequestHash(t *testing.T) { + oldRH := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRH }() + + xdsC := fakeclient.NewClient() + xdsR, tcc, cancel := testSetup(t, setupOpts{ + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, + }) + defer xdsR.Close() + defer cancel() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + waitForWatchListener(ctx, t, xdsC, targetStr) + xdsC.InvokeWatchListenerCallback(xdsclient.ListenerUpdate{RouteConfigName: routeStr, HTTPFilters: routerFilterList}, nil) + waitForWatchRouteConfig(ctx, t, xdsC, routeStr) + // Invoke watchAPI callback with a good service update (with hash policies + // specified) and wait for UpdateState method to be called on ClientConn. 
+ xdsC.InvokeWatchRouteConfigCallback(xdsclient.RouteConfigUpdate{ + VirtualHosts: []*xdsclient.VirtualHost{ + { + Domains: []string{targetStr}, + Routes: []*xdsclient.Route{{ + Prefix: newStringP(""), + WeightedClusters: map[string]xdsclient.WeightedCluster{ + "cluster_1": {Weight: 75}, + "cluster_2": {Weight: 25}, + }, + HashPolicies: []*xdsclient.HashPolicy{{ + HashPolicyType: xdsclient.HashPolicyTypeHeader, + HeaderName: ":path", + }}, + }}, + }, + }, + }, nil) + + ctx, cancel = context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + gotState, err := tcc.stateCh.Receive(ctx) + if err != nil { + t.Fatalf("Error waiting for UpdateState to be called: %v", err) + } + rState := gotState.(resolver.State) + cs := iresolver.GetConfigSelector(rState) + if cs == nil { + t.Error("received nil config selector") + } + // Selecting a config when there was a hash policy specified in the route + // that will be selected should put a request hash in the config's context. + res, err := cs.SelectConfig(iresolver.RPCInfo{Context: metadata.NewIncomingContext(context.Background(), metadata.Pairs(":path", "/products"))}) + if err != nil { + t.Fatalf("Unexpected error from cs.SelectConfig(_): %v", err) + } + requestHashGot := ringhash.GetRequestHashForTesting(res.Context) + requestHashWant := xxhash.Sum64String("/products") + if requestHashGot != requestHashWant { + t.Fatalf("requestHashGot = %v, requestHashWant = %v", requestHashGot, requestHashWant) + } +} + // TestXDSResolverRemovedWithRPCs tests the case where a config selector sends // an empty update to the resolver after the resource is removed. func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() defer xdsR.Close() @@ -464,7 +545,7 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) @@ -507,7 +588,7 @@ func (s) TestXDSResolverRemovedWithRPCs(t *testing.T) { func (s) TestXDSResolverRemovedResource(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer cancel() defer xdsR.Close() @@ -524,7 +605,7 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) @@ -615,7 +696,7 @@ func (s) TestXDSResolverRemovedResource(t *testing.T) { func (s) TestXDSResolverWRR(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, 
error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -635,7 +716,7 @@ func (s) TestXDSResolverWRR(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{ "A": {Weight: 5}, "B": {Weight: 10}, }}}, @@ -673,10 +754,9 @@ func (s) TestXDSResolverWRR(t *testing.T) { } func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { - defer func(old bool) { env.TimeoutSupport = old }(env.TimeoutSupport) xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -696,7 +776,7 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{ + Routes: []*xdsclient.Route{{ Prefix: newStringP("/foo"), WeightedClusters: map[string]xdsclient.WeightedCluster{"A": {Weight: 1}}, MaxStreamDuration: newDurationP(5 * time.Second), @@ -727,35 +807,25 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { } testCases := []struct { - name string - method string - timeoutSupport bool - want *time.Duration + name string + method string + want *time.Duration }{{ - name: "RDS setting", - method: "/foo/method", - timeoutSupport: true, - want: newDurationP(5 * time.Second), - }, { - name: "timeout support disabled", - method: "/foo/method", - timeoutSupport: false, - want: nil, + name: "RDS setting", + method: "/foo/method", + want: newDurationP(5 * time.Second), }, { - name: "explicit zero in RDS; ignore LDS", - method: "/bar/method", - timeoutSupport: true, - want: nil, + name: "explicit zero in RDS; ignore LDS", + method: "/bar/method", + want: nil, }, { - name: "no config in RDS; fallback to LDS", - method: "/baz/method", - timeoutSupport: true, - want: newDurationP(time.Second), + name: "no config in RDS; fallback to LDS", + method: "/baz/method", + want: newDurationP(time.Second), }} for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - env.TimeoutSupport = tc.timeoutSupport req := iresolver.RPCInfo{ Method: tc.method, Context: context.Background(), @@ -779,7 +849,7 @@ func (s) TestXDSResolverMaxStreamDuration(t *testing.T) { func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -796,7 +866,7 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"test-cluster-1": {Weight: 1}}}}, }, }, }, nil) @@ -846,7 +916,7 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": 
{Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) @@ -856,7 +926,7 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) @@ -895,7 +965,7 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{"NEW": {Weight: 1}}}}, }, }, }, nil) @@ -928,7 +998,7 @@ func (s) TestXDSResolverDelayedOnCommitted(t *testing.T) { func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -954,7 +1024,7 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, + Routes: []*xdsclient.Route{{Prefix: newStringP(""), WeightedClusters: map[string]xdsclient.WeightedCluster{cluster: {Weight: 1}}}}, }, }, }, nil) @@ -982,7 +1052,7 @@ func (s) TestXDSResolverGoodUpdateAfterError(t *testing.T) { func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -1028,7 +1098,7 @@ func (s) TestXDSResolverResourceNotFoundError(t *testing.T) { func (s) TestXDSResolverMultipleLDSUpdates(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -1203,7 +1273,7 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { t.Run(tc.name, func(t *testing.T) { xdsC := fakeclient.NewClient() xdsR, tcc, cancel := testSetup(t, setupOpts{ - xdsClientFunc: func() (xdsClientInterface, error) { return xdsC, nil }, + xdsClientFunc: func() (xdsclient.XDSClient, error) { return xdsC, nil }, }) defer xdsR.Close() defer cancel() @@ -1229,7 +1299,7 @@ func (s) TestXDSResolverHTTPFilters(t *testing.T) { VirtualHosts: []*xdsclient.VirtualHost{ { Domains: []string{targetStr}, - Routes: []*client.Route{{ + Routes: []*xdsclient.Route{{ Prefix: newStringP("1"), WeightedClusters: map[string]xdsclient.WeightedCluster{ "A": {Weight: 1}, "B": {Weight: 1}, diff --git a/xds/internal/server/conn_wrapper.go b/xds/internal/server/conn_wrapper.go index a02d75b21445..43be4673655a 100644 --- a/xds/internal/server/conn_wrapper.go +++ 
b/xds/internal/server/conn_wrapper.go @@ -27,7 +27,7 @@ import ( "google.golang.org/grpc/credentials/tls/certprovider" xdsinternal "google.golang.org/grpc/internal/credentials/xds" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/xdsclient" ) // connWrapper is a thin wrapper around a net.Conn returned by Accept(). It @@ -102,7 +102,7 @@ func (c *connWrapper) XDSHandshakeInfo() (*xdsinternal.HandshakeInfo, error) { cpc := c.parent.xdsC.BootstrapConfig().CertProviderConfigs // Identity provider name is mandatory on the server-side, and this is - // enforced when the resource is received at the xdsClient layer. + // enforced when the resource is received at the XDSClient layer. secCfg := c.filterChain.SecurityCfg ip, err := buildProviderFunc(cpc, secCfg.IdentityInstanceName, secCfg.IdentityCertName, true, false) if err != nil { diff --git a/xds/internal/server/listener_wrapper.go b/xds/internal/server/listener_wrapper.go index 17f31f28f576..727e95b94f13 100644 --- a/xds/internal/server/listener_wrapper.go +++ b/xds/internal/server/listener_wrapper.go @@ -31,8 +31,8 @@ import ( internalbackoff "google.golang.org/grpc/internal/backoff" internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) var ( @@ -88,9 +88,9 @@ func prefixLogger(p *listenerWrapper) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf("[xds-server-listener %p] ", p)) } -// XDSClientInterface wraps the methods on the xdsClient which are required by +// XDSClient wraps the methods on the XDSClient which are required by // the listenerWrapper. -type XDSClientInterface interface { +type XDSClient interface { WatchListener(string, func(xdsclient.ListenerUpdate, error)) func() BootstrapConfig() *bootstrap.Config } @@ -104,8 +104,8 @@ type ListenerWrapperParams struct { // XDSCredsInUse specifies whether or not the user expressed interest to // receive security configuration from the control plane. XDSCredsInUse bool - // XDSClient provides the functionality from the xdsClient required here. - XDSClient XDSClientInterface + // XDSClient provides the functionality from the XDSClient required here. + XDSClient XDSClient // ModeCallback is the callback to invoke when the serving mode changes. ModeCallback ServingModeCallback } @@ -152,7 +152,7 @@ type listenerWrapper struct { name string xdsCredsInUse bool - xdsC XDSClientInterface + xdsC XDSClient cancelWatch func() modeCallback ServingModeCallback @@ -168,7 +168,7 @@ type listenerWrapper struct { // instead of a vanilla channel simplifies the update handler as it need not // keep track of whether the received update is the first one or not. goodUpdate *grpcsync.Event - // A small race exists in the xdsClient code between the receipt of an xDS + // A small race exists in the XDSClient code between the receipt of an xDS // response and the user cancelling the associated watch. In this window, // the registered callback may be invoked after the watch is canceled, and // the user is expected to work around this. 
This event signifies that the @@ -299,14 +299,14 @@ func (l *listenerWrapper) handleListenerUpdate(update xdsclient.ListenerUpdate, // Make sure that the socket address on the received Listener resource // matches the address of the net.Listener passed to us by the user. This - // check is done here instead of at the xdsClient layer because of the + // check is done here instead of at the XDSClient layer because of the // following couple of reasons: - // - xdsClient cannot know the listening address of every listener in the + // - XDSClient cannot know the listening address of every listener in the // system, and hence cannot perform this check. // - this is a very context-dependent check and only the server has the // appropriate context to perform this check. // - // What this means is that the xdsClient has ACKed a resource which can push + // What this means is that the XDSClient has ACKed a resource which can push // the server into a "not serving" mode. This is not ideal, but this is // what we have decided to do. See gRPC A36 for more details. ilc := update.InboundListenerCfg diff --git a/xds/internal/server/listener_wrapper_test.go b/xds/internal/server/listener_wrapper_test.go index b22f647a93ca..8a79e2321dd9 100644 --- a/xds/internal/server/listener_wrapper_test.go +++ b/xds/internal/server/listener_wrapper_test.go @@ -30,12 +30,13 @@ import ( v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" wrapperspb "github.com/golang/protobuf/ptypes/wrappers" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( @@ -82,6 +83,14 @@ var listenerWithFilterChains = &v3listenerpb.Listener{ }), }, }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{}), + }, + }, + }, }, }, } @@ -117,7 +126,10 @@ type fakeListener struct { } func (fl *fakeListener) Accept() (net.Conn, error) { - cne := <-fl.acceptCh + cne, ok := <-fl.acceptCh + if !ok { + return nil, errors.New("a non-temporary error") + } return cne.conn, cne.err } @@ -156,7 +168,7 @@ func (fc *fakeConn) Close() error { func newListenerWrapper(t *testing.T) (*listenerWrapper, <-chan struct{}, *fakeclient.Client, *fakeListener, func()) { t.Helper() - // Create a listener wrapper with a fake listener and fake xdsClient and + // Create a listener wrapper with a fake listener and fake XDSClient and // verify that it extracts the host and port from the passed in listener. 
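The fakeListener used in these tests is an instance of a reusable pattern: a channel-driven fake net.Listener whose Accept treats a closed channel as a permanent failure, so the wrapped server's accept loop exits instead of spinning. A minimal sketch of the same shape, with hypothetical names (only Accept does real work; the embedded net.Listener is nil, so calling Addr or Close on it would panic):

    package fakelistener

    import (
        "errors"
        "net"
    )

    type connAndErr struct {
        conn net.Conn
        err  error
    }

    // chanListener satisfies net.Listener via the embedded interface and
    // overrides only Accept, which blocks until a conn/err pair is sent.
    type chanListener struct {
        net.Listener
        acceptCh chan connAndErr
    }

    func (l *chanListener) Accept() (net.Conn, error) {
        ce, ok := <-l.acceptCh
        if !ok {
            // Channel closed: report a non-temporary error so the caller's
            // accept loop terminates rather than retrying forever.
            return nil, errors.New("listener closed")
        }
        return ce.conn, ce.err
    }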
lis := &fakeListener{ acceptCh: make(chan connAndErr, 1), @@ -262,6 +274,7 @@ func (s) TestListenerWrapper_Accept(t *testing.T) { }}, nil) ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() + defer close(lis.acceptCh) select { case <-ctx.Done(): t.Fatalf("timeout waiting for the ready channel to be written to after receipt of a good Listener update") diff --git a/xds/internal/test/xds_client_integration_test.go b/xds/internal/test/xds_client_integration_test.go index 713331b325e0..e94b70c9fb64 100644 --- a/xds/internal/test/xds_client_integration_test.go +++ b/xds/internal/test/xds_client_integration_test.go @@ -68,7 +68,7 @@ func (s) TestClientSideXDS(t *testing.T) { port, cleanup := clientSetup(t) defer cleanup() - serviceName := xdsServiceName + "-client-side-xds" + const serviceName = "my-service-client-side-xds" resources := e2e.DefaultClientResources(e2e.ResourceParams{ DialTarget: serviceName, NodeID: xdsClientNodeID, @@ -81,7 +81,7 @@ func (s) TestClientSideXDS(t *testing.T) { } // Create a ClientConn and make a successful RPC. - cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials())) + cc, err := grpc.Dial(fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithResolvers(xdsResolverBuilder)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } diff --git a/xds/internal/test/xds_integration_test.go b/xds/internal/test/xds_integration_test.go index a41fec929762..b66cdd59cafe 100644 --- a/xds/internal/test/xds_integration_test.go +++ b/xds/internal/test/xds_integration_test.go @@ -40,7 +40,9 @@ import ( "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/xds/env" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/testdata" + "google.golang.org/grpc/xds" "google.golang.org/grpc/xds/internal/testutils/e2e" xdsinternal "google.golang.org/grpc/internal/xds" @@ -71,8 +73,10 @@ func (*testService) EmptyCall(context.Context, *testpb.Empty) (*testpb.Empty, er var ( // Globals corresponding to the single instance of the xDS management server // which is spawned for all the tests in this package. - managementServer *e2e.ManagementServer - xdsClientNodeID string + managementServer *e2e.ManagementServer + xdsClientNodeID string + bootstrapContents []byte + xdsResolverBuilder resolver.Builder ) // TestMain sets up an xDS management server, runs all tests, and stops the @@ -158,30 +162,33 @@ func createClientTLSCredentials(t *testing.T) credentials.TransportCredentials { // - sets up the global variables which refer to this management server and the // nodeID to be used when talking to this management server. // -// Returns a function to be invoked by the caller to stop the management server. -func setupManagementServer() (func(), error) { +// Returns a function to be invoked by the caller to stop the management +// server. +func setupManagementServer() (cleanup func(), err error) { // Turn on the env var protection for client-side security. origClientSideSecurityEnvVar := env.ClientSideSecuritySupport env.ClientSideSecuritySupport = true // Spin up an xDS management server on a local port. - var err error managementServer, err = e2e.StartManagementServer() if err != nil { return nil, err } + defer func() { + if err != nil { + managementServer.Stop() + } + }() // Create a directory to hold certs and key files used on the server side. 
serverDir, err := createTmpDirWithFiles("testServerSideXDS*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") if err != nil { - managementServer.Stop() return nil, err } // Create a directory to hold certs and key files used on the client side. clientDir, err := createTmpDirWithFiles("testClientSideXDS*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") if err != nil { - managementServer.Stop() return nil, err } @@ -194,7 +201,7 @@ func setupManagementServer() (func(), error) { // Create a bootstrap file in a temporary directory. xdsClientNodeID = uuid.New().String() - bootstrapCleanup, err := xdsinternal.SetupBootstrapFile(xdsinternal.BootstrapOptions{ + bootstrapContents, err = xdsinternal.BootstrapContents(xdsinternal.BootstrapOptions{ Version: xdsinternal.TransportV3, NodeID: xdsClientNodeID, ServerURI: managementServer.Address, @@ -202,13 +209,15 @@ func setupManagementServer() (func(), error) { ServerListenerResourceNameTemplate: e2e.ServerListenerResourceNameTemplate, }) if err != nil { - managementServer.Stop() + return nil, err + } + xdsResolverBuilder, err = xds.NewXDSResolverWithConfigForTesting(bootstrapContents) + if err != nil { return nil, err } return func() { managementServer.Stop() - bootstrapCleanup() env.ClientSideSecuritySupport = origClientSideSecurityEnvVar }, nil } diff --git a/xds/internal/test/xds_server_integration_test.go b/xds/internal/test/xds_server_integration_test.go index 6511a6134cf8..4bf13e9305be 100644 --- a/xds/internal/test/xds_server_integration_test.go +++ b/xds/internal/test/xds_server_integration_test.go @@ -46,8 +46,6 @@ const ( certFile = "cert.pem" keyFile = "key.pem" rootFile = "ca.pem" - - xdsServiceName = "my-service" ) // setupGRPCServer performs the following: @@ -70,7 +68,7 @@ func setupGRPCServer(t *testing.T) (net.Listener, func()) { } // Initialize an xDS-enabled gRPC server and register the stubServer on it. - server := xds.NewGRPCServer(grpc.Creds(creds)) + server := xds.NewGRPCServer(grpc.Creds(creds), xds.BootstrapContentsForTesting(bootstrapContents)) testpb.RegisterTestServiceServer(server, &testService{}) // Create a local listener and pass it to Serve(). @@ -123,7 +121,7 @@ func (s) TestServerSideXDS_Fallback(t *testing.T) { if err != nil { t.Fatalf("failed to retrieve host and port of server: %v", err) } - serviceName := xdsServiceName + "-fallback" + const serviceName = "my-service-fallback" resources := e2e.DefaultClientResources(e2e.ResourceParams{ DialTarget: serviceName, NodeID: xdsClientNodeID, @@ -154,7 +152,7 @@ func (s) TestServerSideXDS_Fallback(t *testing.T) { // Create a ClientConn with the xds scheme and make a successful RPC. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds)) + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds), grpc.WithResolvers(xdsResolverBuilder)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } @@ -205,7 +203,7 @@ func (s) TestServerSideXDS_FileWatcherCerts(t *testing.T) { // Create xDS resources to be consumed on the client side. This // includes the listener, route configuration, cluster (with // security configuration) and endpoint resources. 
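Each dial site in these tests follows the same recipe: build a resolver.Builder from the in-memory bootstrap contents with xds.NewXDSResolverWithConfigForTesting, then scope it to the ClientConn via grpc.WithResolvers, so the xds scheme resolves against the test management server instead of a file named by GRPC_XDS_BOOTSTRAP. A condensed sketch of that wiring (the target name is a placeholder):

    import (
        "google.golang.org/grpc"
        "google.golang.org/grpc/credentials/insecure"
        "google.golang.org/grpc/xds"
    )

    func dialWithTestResolver(bootstrapContents []byte) (*grpc.ClientConn, error) {
        builder, err := xds.NewXDSResolverWithConfigForTesting(bootstrapContents)
        if err != nil {
            return nil, err
        }
        // WithResolvers registers the builder for this ClientConn only; the
        // globally registered xds resolver is left untouched.
        return grpc.Dial("xds:///my-service",
            grpc.WithTransportCredentials(insecure.NewCredentials()),
            grpc.WithResolvers(builder))
    }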
- serviceName := xdsServiceName + "-file-watcher-certs-" + test.name + serviceName := "my-service-file-watcher-certs-" + test.name resources := e2e.DefaultClientResources(e2e.ResourceParams{ DialTarget: serviceName, NodeID: xdsClientNodeID, @@ -236,7 +234,7 @@ func (s) TestServerSideXDS_FileWatcherCerts(t *testing.T) { // Create a ClientConn with the xds scheme and make an RPC. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds)) + cc, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(creds), grpc.WithResolvers(xdsResolverBuilder)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } @@ -270,7 +268,7 @@ func (s) TestServerSideXDS_SecurityConfigChange(t *testing.T) { if err != nil { t.Fatalf("failed to retrieve host and port of server: %v", err) } - serviceName := xdsServiceName + "-security-config-change" + const serviceName = "my-service-security-config-change" resources := e2e.DefaultClientResources(e2e.ResourceParams{ DialTarget: serviceName, NodeID: xdsClientNodeID, @@ -301,7 +299,7 @@ func (s) TestServerSideXDS_SecurityConfigChange(t *testing.T) { // Create a ClientConn with the xds scheme and make a successful RPC. ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() - xdsCC, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(xdsCreds)) + xdsCC, err := grpc.DialContext(ctx, fmt.Sprintf("xds:///%s", serviceName), grpc.WithTransportCredentials(xdsCreds), grpc.WithResolvers(xdsResolverBuilder)) if err != nil { t.Fatalf("failed to dial local test server: %v", err) } diff --git a/xds/internal/test/xds_server_serving_mode_test.go b/xds/internal/test/xds_server_serving_mode_test.go index 0b70f7b06aed..8fb346298abf 100644 --- a/xds/internal/test/xds_server_serving_mode_test.go +++ b/xds/internal/test/xds_server_serving_mode_test.go @@ -105,7 +105,7 @@ func (s) TestServerSideXDS_ServingModeChanges(t *testing.T) { }) // Initialize an xDS-enabled gRPC server and register the stubServer on it. - server := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt) + server := xds.NewGRPCServer(grpc.Creds(creds), modeChangeOpt, xds.BootstrapContentsForTesting(bootstrapContents)) defer server.Stop() testpb.RegisterTestServiceServer(server, &testService{}) diff --git a/xds/internal/testutils/fakeclient/client.go b/xds/internal/testutils/fakeclient/client.go index f85f0ecbf3f3..255454080360 100644 --- a/xds/internal/testutils/fakeclient/client.go +++ b/xds/internal/testutils/fakeclient/client.go @@ -22,15 +22,21 @@ package fakeclient import ( "context" + "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // Client is a fake implementation of an xds client. It exposes a bunch of // channels to signal the occurrence of various events. type Client struct { + // Embed XDSClient so this fake client implements the interface, but it's + // never set (it's always nil). This may cause nil panic since not all the + // methods are implemented. 
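Embedding the interface, as the comment above describes, is the standard Go trick for keeping a partial fake compiling as the interface grows: calls to methods the fake does not override fall through to the nil embedded value and panic only if actually exercised. A tiny compilable illustration with hypothetical names:

    type Greeter interface {
        Hello() string
        Goodbye() string
    }

    // fakeGreeter satisfies Greeter through the embedded interface value.
    // Hello is overridden; calling Goodbye on a fakeGreeter whose embedded
    // Greeter is nil panics, which is acceptable in tests that never use it.
    type fakeGreeter struct {
        Greeter
    }

    func (fakeGreeter) Hello() string { return "hello" }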
+ xdsclient.XDSClient + name string ldsWatchCh *testutils.Channel rdsWatchCh *testutils.Channel @@ -41,14 +47,16 @@ type Client struct { cdsCancelCh *testutils.Channel edsCancelCh *testutils.Channel loadReportCh *testutils.Channel - closeCh *testutils.Channel + lrsCancelCh *testutils.Channel loadStore *load.Store bootstrapCfg *bootstrap.Config ldsCb func(xdsclient.ListenerUpdate, error) rdsCb func(xdsclient.RouteConfigUpdate, error) cdsCbs map[string]func(xdsclient.ClusterUpdate, error) - edsCb func(xdsclient.EndpointsUpdate, error) + edsCbs map[string]func(xdsclient.EndpointsUpdate, error) + + Closed *grpcsync.Event // fired when Close is called. } // WatchListener registers a LDS watch. @@ -172,10 +180,10 @@ func (xdsC *Client) WaitForCancelClusterWatch(ctx context.Context) (string, erro // WatchEndpoints registers an EDS watch for provided clusterName. func (xdsC *Client) WatchEndpoints(clusterName string, callback func(xdsclient.EndpointsUpdate, error)) (cancel func()) { - xdsC.edsCb = callback + xdsC.edsCbs[clusterName] = callback xdsC.edsWatchCh.Send(clusterName) return func() { - xdsC.edsCancelCh.Send(nil) + xdsC.edsCancelCh.Send(clusterName) } } @@ -193,15 +201,28 @@ func (xdsC *Client) WaitForWatchEDS(ctx context.Context) (string, error) { // // Not thread safe with WatchEndpoints. Only call this after // WaitForWatchEDS. -func (xdsC *Client) InvokeWatchEDSCallback(update xdsclient.EndpointsUpdate, err error) { - xdsC.edsCb(update, err) +func (xdsC *Client) InvokeWatchEDSCallback(name string, update xdsclient.EndpointsUpdate, err error) { + if len(xdsC.edsCbs) != 1 { + // This may panic if name isn't found. But it's fine for tests. + xdsC.edsCbs[name](update, err) + return + } + // Preserve the previous behavior: when exactly one callback is registered, + // invoke it regardless of the name argument. + for n := range xdsC.edsCbs { + name = n + } + xdsC.edsCbs[name](update, err) } // WaitForCancelEDSWatch waits for an EDS watch to be cancelled and returns // context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForCancelEDSWatch(ctx context.Context) error { - _, err := xdsC.edsCancelCh.Receive(ctx) - return err +func (xdsC *Client) WaitForCancelEDSWatch(ctx context.Context) (string, error) { + edsNameReceived, err := xdsC.edsCancelCh.Receive(ctx) + if err != nil { + return "", err + } + return edsNameReceived.(string), err } // ReportLoadArgs wraps the arguments passed to ReportLoad. @@ -213,7 +234,16 @@ type ReportLoadArgs struct { // ReportLoad starts reporting load about clusterName to server. func (xdsC *Client) ReportLoad(server string) (loadStore *load.Store, cancel func()) { xdsC.loadReportCh.Send(ReportLoadArgs{Server: server}) - return xdsC.loadStore, func() {} + return xdsC.loadStore, func() { + xdsC.lrsCancelCh.Send(nil) + } +} + +// WaitForCancelReportLoad waits for a load report to be cancelled and returns +// context.DeadlineExceeded otherwise. +func (xdsC *Client) WaitForCancelReportLoad(ctx context.Context) error { + _, err := xdsC.lrsCancelCh.Receive(ctx) + return err } // LoadStore returns the underlying load data store. @@ -225,19 +255,15 @@ func (xdsC *Client) LoadStore() *load.Store { // returns the arguments passed to it. func (xdsC *Client) WaitForReportLoad(ctx context.Context) (ReportLoadArgs, error) { val, err := xdsC.loadReportCh.Receive(ctx) - return val.(ReportLoadArgs), err + if err != nil { + return ReportLoadArgs{}, err + } + return val.(ReportLoadArgs), nil } -// Close closes the xds client. +// Close fires xdsC.Closed, indicating it was called.
func (xdsC *Client) Close() { - xdsC.closeCh.Send(nil) -} - -// WaitForClose waits for Close to be invoked on this client and returns -// context.DeadlineExceeded otherwise. -func (xdsC *Client) WaitForClose(ctx context.Context) error { - _, err := xdsC.closeCh.Receive(ctx) - return err + xdsC.Closed.Fire() } // BootstrapConfig returns the bootstrap config. @@ -269,14 +295,16 @@ func NewClientWithName(name string) *Client { ldsWatchCh: testutils.NewChannel(), rdsWatchCh: testutils.NewChannel(), cdsWatchCh: testutils.NewChannelWithSize(10), - edsWatchCh: testutils.NewChannel(), + edsWatchCh: testutils.NewChannelWithSize(10), ldsCancelCh: testutils.NewChannel(), rdsCancelCh: testutils.NewChannel(), cdsCancelCh: testutils.NewChannelWithSize(10), - edsCancelCh: testutils.NewChannel(), + edsCancelCh: testutils.NewChannelWithSize(10), loadReportCh: testutils.NewChannel(), - closeCh: testutils.NewChannel(), + lrsCancelCh: testutils.NewChannel(), loadStore: load.NewStore(), cdsCbs: make(map[string]func(xdsclient.ClusterUpdate, error)), + edsCbs: make(map[string]func(xdsclient.EndpointsUpdate, error)), + Closed: grpcsync.NewEvent(), } } diff --git a/xds/internal/xdsclient/attributes.go b/xds/internal/xdsclient/attributes.go new file mode 100644 index 000000000000..d2357df0727c --- /dev/null +++ b/xds/internal/xdsclient/attributes.go @@ -0,0 +1,59 @@ +/* + * Copyright 2021 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package xdsclient + +import ( + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/load" +) + +type clientKeyType string + +const clientKey = clientKeyType("grpc.xds.internal.client.Client") + +// XDSClient is a full fledged gRPC client which queries a set of discovery APIs +// (collectively termed as xDS) on a remote management server, to discover +// various dynamic resources. +type XDSClient interface { + WatchListener(string, func(ListenerUpdate, error)) func() + WatchRouteConfig(string, func(RouteConfigUpdate, error)) func() + WatchCluster(string, func(ClusterUpdate, error)) func() + WatchEndpoints(clusterName string, edsCb func(EndpointsUpdate, error)) (cancel func()) + ReportLoad(server string) (*load.Store, func()) + + DumpLDS() (string, map[string]UpdateWithMD) + DumpRDS() (string, map[string]UpdateWithMD) + DumpCDS() (string, map[string]UpdateWithMD) + DumpEDS() (string, map[string]UpdateWithMD) + + BootstrapConfig() *bootstrap.Config + Close() +} + +// FromResolverState returns the Client from state, or nil if not present. +func FromResolverState(state resolver.State) XDSClient { + cs, _ := state.Attributes.Value(clientKey).(XDSClient) + return cs +} + +// SetClient sets c in state and returns the new state. 
+func SetClient(state resolver.State, c XDSClient) resolver.State { + state.Attributes = state.Attributes.WithValues(clientKey, c) + return state +} diff --git a/xds/internal/client/bootstrap/bootstrap.go b/xds/internal/xdsclient/bootstrap/bootstrap.go similarity index 94% rename from xds/internal/client/bootstrap/bootstrap.go rename to xds/internal/xdsclient/bootstrap/bootstrap.go index a3fb5c0816b4..fa229d99593e 100644 --- a/xds/internal/client/bootstrap/bootstrap.go +++ b/xds/internal/xdsclient/bootstrap/bootstrap.go @@ -127,16 +127,18 @@ func bootstrapConfigFromEnvVariable() ([]byte, error) { // // The format of the bootstrap file will be as follows: // { -// "xds_server": { -// "server_uri": , -// "channel_creds": [ -// { -// "type": , -// "config": -// } -// ], -// "server_features": [ ... ], -// }, +// "xds_servers": [ +// { +// "server_uri": , +// "channel_creds": [ +// { +// "type": , +// "config": +// } +// ], +// "server_features": [ ... ], +// } +// ], // "node": , // "certificate_providers" : { // "default": { @@ -160,13 +162,19 @@ func bootstrapConfigFromEnvVariable() ([]byte, error) { // fields left unspecified, in which case the caller should use some sane // defaults. func NewConfig() (*Config, error) { - config := &Config{} - data, err := bootstrapConfigFromEnvVariable() if err != nil { return nil, fmt.Errorf("xds: Failed to read bootstrap config: %v", err) } logger.Debugf("Bootstrap content: %s", data) + return NewConfigFromContents(data) +} + +// NewConfigFromContents returns a new Config using the specified bootstrap +// file contents instead of reading the environment variable. This is only +// suitable for testing purposes. +func NewConfigFromContents(data []byte) (*Config, error) { + config := &Config{} var jsonData map[string]json.RawMessage if err := json.Unmarshal(data, &jsonData); err != nil { diff --git a/xds/internal/client/bootstrap/bootstrap_test.go b/xds/internal/xdsclient/bootstrap/bootstrap_test.go similarity index 100% rename from xds/internal/client/bootstrap/bootstrap_test.go rename to xds/internal/xdsclient/bootstrap/bootstrap_test.go diff --git a/xds/internal/client/bootstrap/logging.go b/xds/internal/xdsclient/bootstrap/logging.go similarity index 100% rename from xds/internal/client/bootstrap/logging.go rename to xds/internal/xdsclient/bootstrap/logging.go diff --git a/xds/internal/client/callback.go b/xds/internal/xdsclient/callback.go similarity index 90% rename from xds/internal/client/callback.go rename to xds/internal/xdsclient/callback.go index ac6f0151a5f9..b8ad0ec76362 100644 --- a/xds/internal/client/callback.go +++ b/xds/internal/xdsclient/callback.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import "google.golang.org/grpc/internal/pretty" @@ -84,14 +84,16 @@ func (c *clientImpl) NewListeners(updates map[string]ListenerUpdate, metadata Up // On NACK, update overall version to the NACKed resp. c.ldsVersion = metadata.ErrState.Version for name := range updates { - if _, ok := c.ldsWatchers[name]; ok { + if s, ok := c.ldsWatchers[name]; ok { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.ldsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.ldsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + wi.newError(metadata.ErrState.Err) + } } } return @@ -143,14 +145,16 @@ func (c *clientImpl) NewRouteConfigs(updates map[string]RouteConfigUpdate, metad // On NACK, update overall version to the NACKed resp. 
c.rdsVersion = metadata.ErrState.Version for name := range updates { - if _, ok := c.rdsWatchers[name]; ok { + if s, ok := c.rdsWatchers[name]; ok { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.rdsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.rdsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + wi.newError(metadata.ErrState.Err) + } } } return @@ -185,14 +189,16 @@ func (c *clientImpl) NewClusters(updates map[string]ClusterUpdate, metadata Upda // On NACK, update overall version to the NACKed resp. c.cdsVersion = metadata.ErrState.Version for name := range updates { - if _, ok := c.cdsWatchers[name]; ok { + if s, ok := c.cdsWatchers[name]; ok { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.cdsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.cdsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + wi.newError(metadata.ErrState.Err) + } } } return @@ -244,14 +250,16 @@ func (c *clientImpl) NewEndpoints(updates map[string]EndpointsUpdate, metadata U // On NACK, update overall version to the NACKed resp. c.edsVersion = metadata.ErrState.Version for name := range updates { - if _, ok := c.edsWatchers[name]; ok { + if s, ok := c.edsWatchers[name]; ok { // On error, keep previous version for each resource. But update // status and error. mdCopy := c.edsMD[name] mdCopy.ErrState = metadata.ErrState mdCopy.Status = metadata.Status c.edsMD[name] = mdCopy - // TODO: send the NACK error to the watcher. + for wi := range s { + wi.newError(metadata.ErrState.Err) + } } } return @@ -272,3 +280,16 @@ func (c *clientImpl) NewEndpoints(updates map[string]EndpointsUpdate, metadata U } } } + +// NewConnectionError is called by the underlying xdsAPIClient when it receives +// a connection error. The error will be forwarded to all the resource watchers. +func (c *clientImpl) NewConnectionError(err error) { + c.mu.Lock() + defer c.mu.Unlock() + + for _, s := range c.edsWatchers { + for wi := range s { + wi.newError(NewErrorf(ErrorTypeConnection, "xds: error received from xDS stream: %v", err)) + } + } +} diff --git a/xds/internal/client/cds_test.go b/xds/internal/xdsclient/cds_test.go similarity index 99% rename from xds/internal/client/cds_test.go rename to xds/internal/xdsclient/cds_test.go index 82b92372afc3..88bfe21a7bdb 100644 --- a/xds/internal/client/cds_test.go +++ b/xds/internal/xdsclient/cds_test.go @@ -18,7 +18,7 @@ * */ -package client +package xdsclient import ( "regexp" @@ -303,9 +303,6 @@ func (s) TestValidateCluster_Success(t *testing.T) { }, } - origCircuitBreakingSupport := env.CircuitBreakingSupport - env.CircuitBreakingSupport = true - defer func() { env.CircuitBreakingSupport = origCircuitBreakingSupport }() oldAggregateAndDNSSupportEnv := env.AggregateAndDNSSupportEnv env.AggregateAndDNSSupportEnv = true defer func() { env.AggregateAndDNSSupportEnv = oldAggregateAndDNSSupportEnv }() diff --git a/xds/internal/client/client.go b/xds/internal/xdsclient/client.go similarity index 95% rename from xds/internal/client/client.go rename to xds/internal/xdsclient/client.go index 4ed1c3a5833c..a7de226cd292 100644 --- a/xds/internal/client/client.go +++ b/xds/internal/xdsclient/client.go @@ -16,9 +16,9 @@ * */ -// Package client implements a full fledged gRPC client for the xDS API used by -// the xds resolver and balancer implementations. 
-package client +// Package xdsclient implements a full fledged gRPC client for the xDS API used +// by the xds resolver and balancer implementations. +package xdsclient import ( "context" @@ -34,8 +34,8 @@ import ( "google.golang.org/protobuf/types/known/anypb" "google.golang.org/grpc/internal/xds/matcher" - "google.golang.org/grpc/xds/internal/client/load" "google.golang.org/grpc/xds/internal/httpfilter" + "google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc" "google.golang.org/grpc/internal/backoff" @@ -44,8 +44,8 @@ import ( "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/xds/internal" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) var ( @@ -141,6 +141,9 @@ type UpdateHandler interface { // NewEndpoints handles updates to xDS ClusterLoadAssignment (or tersely // referred to as Endpoints) resources. NewEndpoints(map[string]EndpointsUpdate, UpdateMetadata) + // NewConnectionError handles connection errors from the xDS stream. The + // error will be reported to all the resource watchers. + NewConnectionError(err error) } // ServiceStatus is the status of the update. @@ -269,6 +272,28 @@ type VirtualHost struct { HTTPFilterConfigOverride map[string]httpfilter.FilterConfig } +// HashPolicyType specifies the type of HashPolicy from a received RDS Response. +type HashPolicyType int + +const ( + // HashPolicyTypeHeader specifies to hash a Header in the incoming request. + HashPolicyTypeHeader HashPolicyType = iota + // HashPolicyTypeChannelID specifies to hash a unique Identifier of the + // Channel. In grpc-go, this will be done using the ClientConn pointer. + HashPolicyTypeChannelID +) + +// HashPolicy specifies the HashPolicy if the upstream cluster uses a hashing +// load balancer. +type HashPolicy struct { + HashPolicyType HashPolicyType + Terminal bool + // Fields used for type HEADER. + HeaderName string + Regex *regexp.Regexp + RegexSubstitution string +} + // Route is both a specification of how to match a request as well as an // indication of the action to take upon match. type Route struct { @@ -281,6 +306,8 @@ type Route struct { Headers []*HeaderMatcher Fraction *uint32 + HashPolicies []*HashPolicy + // If the matchers above indicate a match, the below configuration is used. WeightedClusters map[string]WeightedCluster // If MaxStreamDuration is nil, it indicates neither of the route action's @@ -579,7 +606,7 @@ func newWithConfig(config *bootstrap.Config, watchExpiryTimeout time.Duration) ( // BootstrapConfig returns the configuration read from the bootstrap file. // Callers must treat the return value as read-only. 
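The HashPolicy type introduced above is consumed by the ring-hash balancer, and the resolver test at the top of this section (via ringhash.GetRequestHashForTesting) shows the contract for the header case: the request hash is the 64-bit xxHash of the matched header's value. A hedged sketch of that computation (the helper name is hypothetical; the real hash generation lives in the xds resolver):

    import "github.com/cespare/xxhash/v2"

    // headerRequestHash mirrors the expectation asserted earlier with
    // xxhash.Sum64String("/products") for a ":path" header hash policy.
    func headerRequestHash(headerValue string) uint64 {
        return xxhash.Sum64String(headerValue)
    }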
-func (c *Client) BootstrapConfig() *bootstrap.Config { +func (c *clientRefCounted) BootstrapConfig() *bootstrap.Config { return c.config } diff --git a/xds/internal/client/client_test.go b/xds/internal/xdsclient/client_test.go similarity index 90% rename from xds/internal/client/client_test.go rename to xds/internal/xdsclient/client_test.go index d57bc20e2c2c..abf51c45b83a 100644 --- a/xds/internal/client/client_test.go +++ b/xds/internal/xdsclient/client_test.go @@ -18,7 +18,7 @@ * */ -package client +package xdsclient import ( "context" @@ -34,9 +34,9 @@ import ( "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/testing/protocmp" ) @@ -187,59 +187,83 @@ func (s) TestWatchCallAnotherWatch(t *testing.T) { wantUpdate := ClusterUpdate{ClusterName: testEDSName} client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } wantUpdate2 := ClusterUpdate{ClusterName: testEDSName + "2"} client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate2}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate2, nil); err != nil { t.Fatal(err) } } -func verifyListenerUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ListenerUpdate) error { +func verifyListenerUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ListenerUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for listener update: %v", err) } gotUpdate := u.(ldsUpdateErr) + if wantErr != nil { + if gotUpdate.err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.err, wantErr) + } + return nil + } if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { return fmt.Errorf("unexpected listenerUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) } return nil } -func verifyRouteConfigUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate RouteConfigUpdate) error { +func verifyRouteConfigUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate RouteConfigUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for route configuration update: %v", err) } gotUpdate := u.(rdsUpdateErr) + if wantErr != nil { + if gotUpdate.err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.err, wantErr) + } + return nil + } if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { return fmt.Errorf("unexpected route config update: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) } return nil } -func verifyClusterUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ClusterUpdate) error { +func verifyClusterUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate ClusterUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for cluster update: %v",
err) } gotUpdate := u.(clusterUpdateErr) - if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate) { + if wantErr != nil { + if gotUpdate.err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.err, wantErr) + } + return nil + } + if !cmp.Equal(gotUpdate.u, wantUpdate) { return fmt.Errorf("unexpected clusterUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) } return nil } -func verifyEndpointsUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate EndpointsUpdate) error { +func verifyEndpointsUpdate(ctx context.Context, updateCh *testutils.Channel, wantUpdate EndpointsUpdate, wantErr error) error { u, err := updateCh.Receive(ctx) if err != nil { return fmt.Errorf("timeout when waiting for endpoints update: %v", err) } gotUpdate := u.(endpointsUpdateErr) + if wantErr != nil { + if gotUpdate.err != wantErr { + return fmt.Errorf("unexpected error: %v, want %v", gotUpdate.err, wantErr) + } + return nil + } if gotUpdate.err != nil || !cmp.Equal(gotUpdate.u, wantUpdate, cmpopts.EquateEmpty()) { return fmt.Errorf("unexpected endpointsUpdate: (%v, %v), want: (%v, nil)", gotUpdate.u, gotUpdate.err, wantUpdate) } @@ -263,7 +287,7 @@ func (s) TestClientNewSingleton(t *testing.T) { defer cleanup() // The first New(). Should create a Client and a new APIClient. - client, err := New() + client, err := newRefCounted() if err != nil { t.Fatalf("failed to create client: %v", err) } @@ -280,7 +304,7 @@ func (s) TestClientNewSingleton(t *testing.T) { // and should not create new API client. const count = 9 for i := 0; i < count; i++ { - tc, terr := New() + tc, terr := newRefCounted() if terr != nil { client.Close() t.Fatalf("%d-th call to New() failed with error: %v", i, terr) @@ -324,7 +348,7 @@ func (s) TestClientNewSingleton(t *testing.T) { // Call New() again after the previous Client is actually closed. Should // create a Client and a new APIClient. 
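The four verify helpers extended above share one control shape: when wantErr is non-nil only the error is compared, otherwise the update must arrive error-free and equal the expectation. Assuming a Go 1.18+ toolchain (newer than this change), the shape condenses to a single generic helper:

    import "fmt"

    func verifyUpdate[T any](got T, gotErr error, want T, wantErr error, equal func(T, T) bool) error {
        if wantErr != nil {
            if gotErr != wantErr {
                return fmt.Errorf("unexpected error: %v, want %v", gotErr, wantErr)
            }
            return nil
        }
        if gotErr != nil || !equal(got, want) {
            return fmt.Errorf("unexpected update: (%v, %v), want: (%v, nil)", got, gotErr, want)
        }
        return nil
    }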
- client2, err2 := New() + client2, err2 := newRefCounted() if err2 != nil { t.Fatalf("failed to create client: %v", err) } diff --git a/xds/internal/client/dump.go b/xds/internal/xdsclient/dump.go similarity index 99% rename from xds/internal/client/dump.go rename to xds/internal/xdsclient/dump.go index 3fd18f6103b3..db9b474f370d 100644 --- a/xds/internal/client/dump.go +++ b/xds/internal/xdsclient/dump.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import anypb "github.com/golang/protobuf/ptypes/any" diff --git a/xds/internal/client/tests/dump_test.go b/xds/internal/xdsclient/dump_test.go similarity index 95% rename from xds/internal/client/tests/dump_test.go rename to xds/internal/xdsclient/dump_test.go index 815850973e3d..5c31b183a6bc 100644 --- a/xds/internal/client/tests/dump_test.go +++ b/xds/internal/xdsclient/dump_test.go @@ -18,7 +18,7 @@ * */ -package tests_test +package xdsclient_test import ( "fmt" @@ -39,9 +39,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdstestutils "google.golang.org/grpc/xds/internal/testutils" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const defaultTestWatchExpiryTimeout = 500 * time.Millisecond @@ -85,6 +85,7 @@ func (s) TestLDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpLDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -111,7 +112,7 @@ func (s) TestLDSConfigDump(t *testing.T) { Raw: r, } } - client.NewListeners(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewListeners(update0, xdsclient.UpdateMetadata{Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpLDS, testVersion, want0); err != nil { @@ -120,7 +121,7 @@ func (s) TestLDSConfigDump(t *testing.T) { const nackVersion = "lds-version-nack" var nackErr = fmt.Errorf("lds nack error") - client.NewListeners( + updateHandler.NewListeners( map[string]xdsclient.ListenerUpdate{ ldsTargets[0]: {}, }, @@ -195,6 +196,7 @@ func (s) TestRDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpRDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -221,7 +223,7 @@ func (s) TestRDSConfigDump(t *testing.T) { Raw: r, } } - client.NewRouteConfigs(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewRouteConfigs(update0, xdsclient.UpdateMetadata{Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpRDS, testVersion, want0); err != nil { @@ -230,7 +232,7 @@ func (s) TestRDSConfigDump(t *testing.T) { const nackVersion = "rds-version-nack" var nackErr = fmt.Errorf("rds nack error") - client.NewRouteConfigs( + updateHandler.NewRouteConfigs( map[string]xdsclient.RouteConfigUpdate{ rdsTargets[0]: {}, }, @@ -305,6 +307,7 @@ func (s) TestCDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. 
if err := compareDump(client.DumpCDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -331,7 +334,7 @@ func (s) TestCDSConfigDump(t *testing.T) { Raw: r, } } - client.NewClusters(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewClusters(update0, xdsclient.UpdateMetadata{Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpCDS, testVersion, want0); err != nil { @@ -340,7 +343,7 @@ func (s) TestCDSConfigDump(t *testing.T) { const nackVersion = "cds-version-nack" var nackErr = fmt.Errorf("cds nack error") - client.NewClusters( + updateHandler.NewClusters( map[string]xdsclient.ClusterUpdate{ cdsTargets[0]: {}, }, @@ -401,6 +404,7 @@ func (s) TestEDSConfigDump(t *testing.T) { t.Fatalf("failed to create client: %v", err) } defer client.Close() + updateHandler := client.(xdsclient.UpdateHandler) // Expected unknown. if err := compareDump(client.DumpEDS, "", map[string]xdsclient.UpdateWithMD{}); err != nil { @@ -427,7 +431,7 @@ func (s) TestEDSConfigDump(t *testing.T) { Raw: r, } } - client.NewEndpoints(update0, xdsclient.UpdateMetadata{Version: testVersion}) + updateHandler.NewEndpoints(update0, xdsclient.UpdateMetadata{Version: testVersion}) // Expect ACK. if err := compareDump(client.DumpEDS, testVersion, want0); err != nil { @@ -436,7 +440,7 @@ func (s) TestEDSConfigDump(t *testing.T) { const nackVersion = "eds-version-nack" var nackErr = fmt.Errorf("eds nack error") - client.NewEndpoints( + updateHandler.NewEndpoints( map[string]xdsclient.EndpointsUpdate{ edsTargets[0]: {}, }, diff --git a/xds/internal/client/eds_test.go b/xds/internal/xdsclient/eds_test.go similarity index 99% rename from xds/internal/client/eds_test.go rename to xds/internal/xdsclient/eds_test.go index 467f25269cdf..ebd26a40f957 100644 --- a/xds/internal/client/eds_test.go +++ b/xds/internal/xdsclient/eds_test.go @@ -18,7 +18,7 @@ * */ -package client +package xdsclient import ( "fmt" diff --git a/xds/internal/client/errors.go b/xds/internal/xdsclient/errors.go similarity index 98% rename from xds/internal/client/errors.go rename to xds/internal/xdsclient/errors.go index 34ae2738db00..4d6cdaaf9b4a 100644 --- a/xds/internal/client/errors.go +++ b/xds/internal/xdsclient/errors.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import "fmt" diff --git a/xds/internal/client/filter_chain.go b/xds/internal/xdsclient/filter_chain.go similarity index 97% rename from xds/internal/client/filter_chain.go rename to xds/internal/xdsclient/filter_chain.go index 66d26d03b634..7089b97594a1 100644 --- a/xds/internal/client/filter_chain.go +++ b/xds/internal/xdsclient/filter_chain.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "errors" @@ -26,7 +26,6 @@ import ( v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" "github.com/golang/protobuf/proto" - "google.golang.org/grpc/xds/internal/version" ) @@ -50,14 +49,11 @@ const ( // FilterChain captures information from within a FilterChain message in a // Listener resource. -// -// Currently, this simply contains the security configuration found in the -// 'transport_socket' field of the filter chain. The actual set of filters -// associated with this filter chain are not captured here, since we do not -// support these filters on the server-side yet. type FilterChain struct { // SecurityCfg contains transport socket security configuration. 
SecurityCfg *SecurityConfig + // HTTPFilters represent the HTTP Filters that comprise this FilterChain. + HTTPFilters []HTTPFilter } // SourceType specifies the connection source IP match type. @@ -395,16 +391,20 @@ func (fci *FilterChainManager) addFilterChainsForSourcePorts(srcEntry *sourcePre } // filterChainFromProto extracts the relevant information from the FilterChain -// proto and stores it in our internal representation. Currently, we only -// process the security configuration stored in the transport_socket field. +// proto and stores it in our internal representation. func filterChainFromProto(fc *v3listenerpb.FilterChain) (*FilterChain, error) { + httpFilters, err := processNetworkFilters(fc.GetFilters()) + if err != nil { + return nil, err + } + filterChain := &FilterChain{HTTPFilters: httpFilters} // If the transport_socket field is not specified, it means that the control // plane has not sent us any security config. This is fine and the server // will use the fallback credentials configured as part of the // xdsCredentials. ts := fc.GetTransportSocket() if ts == nil { - return &FilterChain{}, nil + return filterChain, nil } if name := ts.GetName(); name != transportSocketName { return nil, fmt.Errorf("transport_socket field has unexpected name: %s", name) @@ -431,7 +431,8 @@ func filterChainFromProto(fc *v3listenerpb.FilterChain) (*FilterChain, error) { if sc.RequireClientCert && sc.RootInstanceName == "" { return nil, errors.New("security configuration on the server-side does not contain root certificate provider instance name, but require_client_cert field is set") } - return &FilterChain{SecurityCfg: sc}, nil + filterChain.SecurityCfg = sc + return filterChain, nil } // FilterChainLookupParams wraps parameters to be passed to Lookup. 
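filterChainFromProto now inspects the network filters, and the key mechanic is unwrapping an HttpConnectionManager from a Filter's typed_config Any. A self-contained sketch of that unwrapping, assuming processNetworkFilters does something of this shape (the helper name is hypothetical, and the real function additionally validates the contained HTTP filters for server-side use):

    import (
        "fmt"

        v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3"
        v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3"
    )

    func unwrapHCM(f *v3listenerpb.Filter) (*v3httppb.HttpConnectionManager, error) {
        tc := f.GetTypedConfig()
        if tc == nil {
            return nil, fmt.Errorf("filter %q has no typed_config", f.GetName())
        }
        hcm := &v3httppb.HttpConnectionManager{}
        // UnmarshalTo checks the Any's type URL against the target message
        // before decoding the payload.
        if err := tc.UnmarshalTo(hcm); err != nil {
            return nil, fmt.Errorf("filter %q: %v", f.GetName(), err)
        }
        return hcm, nil
    }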
diff --git a/xds/internal/client/filter_chain_test.go b/xds/internal/xdsclient/filter_chain_test.go similarity index 76% rename from xds/internal/client/filter_chain_test.go rename to xds/internal/xdsclient/filter_chain_test.go index c68e22286763..e330a73a145b 100644 --- a/xds/internal/client/filter_chain_test.go +++ b/xds/internal/xdsclient/filter_chain_test.go @@ -18,7 +18,7 @@ * */ -package client +package xdsclient import ( "fmt" @@ -28,8 +28,11 @@ import ( v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "google.golang.org/protobuf/testing/protocmp" "google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/wrapperspb" @@ -37,6 +40,25 @@ import ( "google.golang.org/grpc/xds/internal/version" ) +var ( + emptyValidNetworkFilters = []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{}), + }, + }, + } + validServerSideHTTPFilter1 = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } + validServerSideHTTPFilter2 = &v3httppb.HttpFilter{ + Name: "serverOnlyCustomFilter2", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: serverOnlyCustomFilterConfig}, + } +) + // TestNewFilterChainImpl_Failure_BadMatchFields verifies cases where we have a // single filter chain with match criteria that contains unsupported fields. 
func TestNewFilterChainImpl_Failure_BadMatchFields(t *testing.T) { @@ -150,11 +172,13 @@ func TestNewFilterChainImpl_Failure_OverlappingMatchingRules(t *testing.T) { FilterChainMatch: &v3listenerpb.FilterChainMatch{ PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16), cidrRangeFromAddressAndPrefixLen("10.0.0.0", 0)}, }, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.2.2", 16)}, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -165,15 +189,19 @@ func TestNewFilterChainImpl_Failure_OverlappingMatchingRules(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_ANY}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_EXTERNAL}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_EXTERNAL}, + Filters: emptyValidNetworkFilters, }, }, }, @@ -186,11 +214,13 @@ func TestNewFilterChainImpl_Failure_OverlappingMatchingRules(t *testing.T) { FilterChainMatch: &v3listenerpb.FilterChainMatch{ SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16), cidrRangeFromAddressAndPrefixLen("10.0.0.0", 0)}, }, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.2.2", 16)}, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -201,12 +231,15 @@ func TestNewFilterChainImpl_Failure_OverlappingMatchingRules(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3, 4, 5}}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{5, 6, 7}}, + Filters: emptyValidNetworkFilters, }, }, }, @@ -243,6 +276,7 @@ func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { TransportSocket: &v3corepb.TransportSocket{Name: "unsupported-transport-socket-name"}, + Filters: emptyValidNetworkFilters, }, }, }, @@ -259,6 +293,7 @@ func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { TypedConfig: testutils.MarshalAny(&v3tlspb.UpstreamTlsContext{}), }, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -278,6 +313,7 @@ func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { }, }, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -294,6 +330,7 @@ func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { TypedConfig: testutils.MarshalAny(&v3tlspb.DownstreamTlsContext{}), }, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -318,6 +355,7 @@ func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -342,6 +380,7 @@ func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -360,6 +399,7 @@ func TestNewFilterChainImpl_Failure_BadSecurityConfig(t 
*testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -377,6 +417,349 @@ func TestNewFilterChainImpl_Failure_BadSecurityConfig(t *testing.T) { } } +// TestNewFilterChainImpl_Failure_BadHTTPFilters verifies cases where the HTTP +// Filters in the filter chain are invalid. +func TestNewFilterChainImpl_Failure_BadHTTPFilters(t *testing.T) { + tests := []struct { + name string + lis *v3listenerpb.Listener + wantErr string + }{ + { + name: "client side HTTP filter", + lis: &v3listenerpb.Listener{ + Name: "grpc/server?xds.resource.listening_address=0.0.0.0:9999", + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + { + Name: "clientOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, + }, + }, + }), + }, + }, + }, + }, + }, + }, + wantErr: "invalid server side HTTP Filters", + }, + { + name: "one valid then one invalid HTTP filter", + lis: &v3listenerpb.Listener{ + Name: "grpc/server?xds.resource.listening_address=0.0.0.0:9999", + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + { + Name: "clientOnlyCustomFilter", + ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, + }, + }, + }), + }, + }, + }, + }, + }, + }, + wantErr: "invalid server side HTTP Filters", + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + _, err := NewFilterChainManager(test.lis) + if err == nil || !strings.Contains(err.Error(), test.wantErr) { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: %s", err, test.wantErr) + } + }) + } +} + +// TestNewFilterChainImpl_Success_HTTPFilters tests the construction of the +// filter chain with valid HTTP Filters present. 
+func TestNewFilterChainImpl_Success_HTTPFilters(t *testing.T) { + tests := []struct { + name string + lis *v3listenerpb.Listener + wantFC *FilterChainManager + }{ + { + name: "singular valid http filter", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + }, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + }}, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + }}, + }, + }, + { + name: "two valid http filters", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + }, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + }}, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + }, + }, + }, + }, + // In the case of two HTTP Connection Manager's being present, the + // second HTTP Connection Manager should be validated, but ignored. 
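
The two-hcms case that follows pins down a subtle rule from the comment above: every HttpConnectionManager in a chain must unmarshal and validate, but only the first one's HTTP filters are kept. A rough sketch of the processing this implies; the function and helper names are assumptions, since the real logic (in filter_chain.go) is outside this diff:

// Hypothetical sketch: validate every HCM, keep only the first one's filters.
func processNetworkFilters(filters []*v3listenerpb.Filter) ([]HTTPFilter, error) {
	var kept []HTTPFilter
	seenHCM := false
	for _, f := range filters {
		hcm := &v3httppb.HttpConnectionManager{}
		if err := ptypes.UnmarshalAny(f.GetTypedConfig(), hcm); err != nil {
			return nil, fmt.Errorf("failed unmarshaling of network filter %q: %v", f.GetName(), err)
		}
		parsed, err := validateServerSideHTTPFilters(hcm.GetHttpFilters()) // assumed helper
		if err != nil {
			return nil, fmt.Errorf("invalid server side HTTP Filters: %v", err)
		}
		if !seenHCM {
			kept = parsed // later HCMs are validated but ignored
			seenHCM = true
		}
	}
	return kept, nil
}
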
+ { + name: "two hcms", + lis: &v3listenerpb.Listener{ + FilterChains: []*v3listenerpb.FilterChain{ + { + Name: "filter-chain-1", + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + }, + }), + }, + }, + { + Name: "hcm2", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + }, + }), + }, + }, + }, + }, + }, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: []*v3listenerpb.Filter{ + { + Name: "hcm", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + validServerSideHTTPFilter2, + }, + }), + }, + }, + { + Name: "hcm2", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ + HttpFilters: []*v3httppb.HttpFilter{ + validServerSideHTTPFilter1, + }, + }), + }, + }, + }, + }, + }, + wantFC: &FilterChainManager{ + dstPrefixMap: map[string]*destPrefixEntry{ + unspecifiedPrefixMapKey: { + srcTypeArr: [3]*sourcePrefixes{ + { + srcPrefixMap: map[string]*sourcePrefixEntry{ + unspecifiedPrefixMapKey: { + srcPortMap: map[int]*FilterChain{ + 0: {HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + }}, + }, + }, + }, + }, + }, + }, + }, + def: &FilterChain{HTTPFilters: []HTTPFilter{ + { + Name: "serverOnlyCustomFilter", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + { + Name: "serverOnlyCustomFilter2", + Filter: serverOnlyHTTPFilter{}, + Config: filterConfig{Cfg: serverOnlyCustomFilterConfig}, + }, + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + gotFC, err := NewFilterChainManager(test.lis) + if err != nil { + t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) + } + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpOpts) { + t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) + } + }) + } +} + // TestNewFilterChainImpl_Success_SecurityConfig verifies cases where the // security configuration in the filter chain contains valid data. 
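
Stepping back from the success cases above: a consumer builds the manager once per Listener update and then, per accepted connection, looks up the matching chain and instantiates its HTTPFilters. A minimal usage sketch, with the FilterChainLookupParams field names assumed from their use later in this file (imports of net, log, and fmt assumed):

// Minimal consumption sketch (lookup parameter fields are assumptions).
fcm, err := NewFilterChainManager(lis)
if err != nil {
	log.Fatalf("NewFilterChainManager() failed: %v", err)
}
fc, err := fcm.Lookup(FilterChainLookupParams{
	IsUnspecifiedListener: true,
	DestAddr:              net.ParseIP("10.1.1.1"),
	SourceAddr:            net.ParseIP("10.1.1.1"),
	SourcePort:            80,
})
if err != nil {
	log.Fatalf("Lookup() failed: %v", err)
}
for _, hf := range fc.HTTPFilters {
	fmt.Printf("server will instantiate filter %q\n", hf.Name)
}
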
func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { @@ -390,10 +773,13 @@ func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { lis: &v3listenerpb.Listener{ FilterChains: []*v3listenerpb.FilterChain{ { - Name: "filter-chain-1", + Name: "filter-chain-1", + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{ + Filters: emptyValidNetworkFilters, + }, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -432,6 +818,7 @@ func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, }, DefaultFilterChain: &v3listenerpb.FilterChain{ @@ -448,6 +835,7 @@ func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, }, wantFC: &FilterChainManager{ @@ -504,6 +892,7 @@ func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, }, DefaultFilterChain: &v3listenerpb.FilterChain{ @@ -528,6 +917,7 @@ func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, }, wantFC: &FilterChainManager{ @@ -573,7 +963,7 @@ func TestNewFilterChainImpl_Success_SecurityConfig(t *testing.T) { if err != nil { t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) } - if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{})) { + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpopts.EquateEmpty()) { t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) } }) @@ -610,7 +1000,8 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { lis: &v3listenerpb.Listener{ FilterChains: []*v3listenerpb.FilterChain{ { - Name: "good-chain", + Name: "good-chain", + Filters: emptyValidNetworkFilters, }, { Name: "unsupported-destination-port", @@ -618,9 +1009,10 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, DestinationPort: &wrapperspb.UInt32Value{Value: 666}, }, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -634,7 +1026,8 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { lis: &v3listenerpb.Listener{ FilterChains: []*v3listenerpb.FilterChain{ { - Name: "good-chain", + Name: "good-chain", + Filters: emptyValidNetworkFilters, }, { Name: "unsupported-server-names", @@ -642,9 +1035,10 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, ServerNames: []string{"example-server"}, }, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -661,7 +1055,8 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { lis: &v3listenerpb.Listener{ FilterChains: []*v3listenerpb.FilterChain{ { - 
Name: "good-chain", + Name: "good-chain", + Filters: emptyValidNetworkFilters, }, { Name: "unsupported-transport-protocol", @@ -669,9 +1064,10 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, TransportProtocol: "tls", }, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -688,7 +1084,8 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { lis: &v3listenerpb.Listener{ FilterChains: []*v3listenerpb.FilterChain{ { - Name: "good-chain", + Name: "good-chain", + Filters: emptyValidNetworkFilters, }, { Name: "unsupported-application-protocol", @@ -696,9 +1093,10 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, ApplicationProtocols: []string{"h2"}, }, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -718,7 +1116,7 @@ func TestNewFilterChainImpl_Success_UnsupportedMatchFields(t *testing.T) { if err != nil { t.Fatalf("NewFilterChainManager() returned err: %v, wantErr: nil", err) } - if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{})) { + if !cmp.Equal(gotFC, test.wantFC, cmp.AllowUnexported(FilterChainManager{}, destPrefixEntry{}, sourcePrefixes{}, sourcePrefixEntry{}), cmpopts.EquateEmpty()) { t.Fatalf("NewFilterChainManager() returned %+v, want: %+v", gotFC, test.wantFC) } }) @@ -740,6 +1138,7 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { { // Unspecified destination prefix. FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, }, { // v4 wildcard destination prefix. @@ -747,6 +1146,7 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("0.0.0.0", 0)}, SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, }, + Filters: emptyValidNetworkFilters, }, { // v6 wildcard destination prefix. 
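
The cmpopts.EquateEmpty() option threaded through these comparisons (and through the Equal methods further down) exists because the constructed manager may hold nil maps or slices where the expected value holds empty ones, and go-cmp distinguishes nil from empty by default. A self-contained illustration:

package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
)

func main() {
	var nilSlice []string // nil, length 0
	emptySlice := []string{}

	fmt.Println(cmp.Equal(nilSlice, emptySlice))                        // false
	fmt.Println(cmp.Equal(nilSlice, emptySlice, cmpopts.EquateEmpty())) // true
}
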
@@ -754,15 +1154,18 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("::", 0)}, SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, }, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}}, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -849,15 +1252,17 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, }, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -901,15 +1306,17 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, }, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -952,6 +1359,7 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ @@ -960,9 +1368,10 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, SourcePorts: []uint32{1, 2, 3}, }, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -1010,15 +1419,18 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, }, { 
FilterChainMatch: &v3listenerpb.FilterChainMatch{ PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 8)}, TransportProtocol: "raw_buffer", }, + Filters: emptyValidNetworkFilters, }, { // This chain will be dropped in favor of the above @@ -1032,6 +1444,7 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("10.0.0.0", 16)}, }, + Filters: emptyValidNetworkFilters, }, { // This chain will be dropped for unsupported server @@ -1040,6 +1453,7 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.1", 32)}, ServerNames: []string{"foo", "bar"}, }, + Filters: emptyValidNetworkFilters, }, { // This chain will be dropped for unsupported transport @@ -1048,6 +1462,7 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.2", 32)}, TransportProtocol: "not-raw-buffer", }, + Filters: emptyValidNetworkFilters, }, { // This chain will be dropped for unsupported @@ -1056,9 +1471,10 @@ func TestNewFilterChainImpl_Success_AllCombinations(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.3", 32)}, ApplicationProtocols: []string{"h2"}, }, + Filters: emptyValidNetworkFilters, }, }, - DefaultFilterChain: &v3listenerpb.FilterChain{}, + DefaultFilterChain: &v3listenerpb.FilterChain{Filters: emptyValidNetworkFilters}, }, wantFC: &FilterChainManager{ dstPrefixMap: map[string]*destPrefixEntry{ @@ -1147,6 +1563,7 @@ func TestLookup_Failures(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, + Filters: emptyValidNetworkFilters, }, }, }, @@ -1165,6 +1582,7 @@ func TestLookup_Failures(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -1184,6 +1602,7 @@ func TestLookup_Failures(t *testing.T) { SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 24)}, SourceType: v3listenerpb.FilterChainMatch_SAME_IP_OR_LOOPBACK, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -1200,12 +1619,14 @@ func TestLookup_Failures(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}, SourcePorts: []uint32{1}, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -1224,6 +1645,7 @@ func TestLookup_Failures(t *testing.T) { FilterChains: []*v3listenerpb.FilterChain{ { FilterChainMatch: &v3listenerpb.FilterChainMatch{SourcePorts: []uint32{1, 2, 3}}, + Filters: emptyValidNetworkFilters, }, }, }, @@ -1246,11 +1668,13 @@ func TestLookup_Failures(t *testing.T) { PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.1", 32)}, ServerNames: []string{"foo"}, }, + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ PrefixRanges: 
[]*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.100.0", 16)}, }, + Filters: emptyValidNetworkFilters, }, }, }, @@ -1293,6 +1717,7 @@ func TestLookup_Successes(t *testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, }, // A default filter chain with an empty transport socket. @@ -1307,12 +1732,14 @@ func TestLookup_Successes(t *testing.T) { }), }, }, + Filters: emptyValidNetworkFilters, }, } lisWithoutDefaultChain := &v3listenerpb.Listener{ FilterChains: []*v3listenerpb.FilterChain{ { TransportSocket: transportSocketWithInstanceName("unspecified-dest-and-source-prefix"), + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ @@ -1320,16 +1747,19 @@ func TestLookup_Successes(t *testing.T) { SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("0.0.0.0", 0)}, }, TransportSocket: transportSocketWithInstanceName("wildcard-prefixes-v4"), + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ SourcePrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("::", 0)}, }, TransportSocket: transportSocketWithInstanceName("wildcard-source-prefix-v6"), + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{PrefixRanges: []*v3corepb.CidrRange{cidrRangeFromAddressAndPrefixLen("192.168.1.1", 16)}}, TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-unspecified-source-type"), + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ @@ -1337,6 +1767,7 @@ func TestLookup_Successes(t *testing.T) { SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, }, TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type"), + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ @@ -1345,6 +1776,7 @@ func TestLookup_Successes(t *testing.T) { SourceType: v3listenerpb.FilterChainMatch_EXTERNAL, }, TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type-specific-source-prefix"), + Filters: emptyValidNetworkFilters, }, { FilterChainMatch: &v3listenerpb.FilterChainMatch{ @@ -1354,6 +1786,7 @@ func TestLookup_Successes(t *testing.T) { SourcePorts: []uint32{80}, }, TransportSocket: transportSocketWithInstanceName("specific-destination-prefix-specific-source-type-specific-source-prefix-specific-source-port"), + Filters: emptyValidNetworkFilters, }, }, } @@ -1462,7 +1895,7 @@ func TestLookup_Successes(t *testing.T) { if err != nil { t.Fatalf("FilterChainManager.Lookup(%v) failed: %v", test.params, err) } - if !cmp.Equal(gotFC, test.wantFC) { + if !cmp.Equal(gotFC, test.wantFC, cmpopts.EquateEmpty()) { t.Fatalf("FilterChainManager.Lookup(%v) = %v, want %v", test.params, gotFC, test.wantFC) } }) @@ -1480,9 +1913,10 @@ func (fci *FilterChainManager) Equal(other *FilterChainManager) bool { return true } switch { - case !cmp.Equal(fci.dstPrefixMap, other.dstPrefixMap): + case !cmp.Equal(fci.dstPrefixMap, other.dstPrefixMap, cmpopts.EquateEmpty()): return false - case !cmp.Equal(fci.def, other.def): + // TODO: Support comparing dstPrefixes slice? 
+ case !cmp.Equal(fci.def, other.def, cmpopts.EquateEmpty(), protocmp.Transform()): return false } return true @@ -1499,7 +1933,7 @@ func (dpe *destPrefixEntry) Equal(other *destPrefixEntry) bool { return false } for i, st := range dpe.srcTypeArr { - if !cmp.Equal(st, other.srcTypeArr[i]) { + if !cmp.Equal(st, other.srcTypeArr[i], cmpopts.EquateEmpty()) { return false } } @@ -1513,7 +1947,8 @@ func (sp *sourcePrefixes) Equal(other *sourcePrefixes) bool { if sp == nil { return true } - return cmp.Equal(sp.srcPrefixMap, other.srcPrefixMap) + // TODO: Support comparing srcPrefixes slice? + return cmp.Equal(sp.srcPrefixMap, other.srcPrefixMap, cmpopts.EquateEmpty()) } func (spe *sourcePrefixEntry) Equal(other *sourcePrefixEntry) bool { @@ -1526,7 +1961,7 @@ func (spe *sourcePrefixEntry) Equal(other *sourcePrefixEntry) bool { switch { case !cmp.Equal(spe.net, other.net): return false - case !cmp.Equal(spe.srcPortMap, other.srcPortMap): + case !cmp.Equal(spe.srcPortMap, other.srcPortMap, cmpopts.EquateEmpty(), protocmp.Transform()): return false } return true diff --git a/xds/internal/client/lds_test.go b/xds/internal/xdsclient/lds_test.go similarity index 95% rename from xds/internal/client/lds_test.go rename to xds/internal/xdsclient/lds_test.go index ad9af4c885a2..012efd16d7b1 100644 --- a/xds/internal/client/lds_test.go +++ b/xds/internal/xdsclient/lds_test.go @@ -18,7 +18,7 @@ * */ -package client +package xdsclient import ( "fmt" @@ -34,7 +34,6 @@ import ( "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/grpc/internal/testutils" - "google.golang.org/grpc/internal/xds/env" "google.golang.org/grpc/xds/internal/httpfilter" "google.golang.org/grpc/xds/internal/version" @@ -186,7 +185,6 @@ func (s) TestUnmarshalListener_ClientSide(t *testing.T) { wantUpdate map[string]ListenerUpdate wantMD UpdateMetadata wantErr bool - disableFI bool // disable fault injection }{ { name: "non-listener resource", @@ -358,18 +356,6 @@ func (s) TestUnmarshalListener_ClientSide(t *testing.T) { Version: testVersion, }, }, - { - name: "v3 with custom filter, fault injection disabled", - resources: []*anypb.Any{v3LisWithFilters(customFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: {RouteConfigName: v3RouteConfigName, MaxStreamDuration: time.Second, Raw: v3LisWithFilters(customFilter)}, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - disableFI: true, - }, { name: "v3 with two filters with same name", resources: []*anypb.Any{v3LisWithFilters(customFilter, customFilter)}, @@ -477,22 +463,6 @@ func (s) TestUnmarshalListener_ClientSide(t *testing.T) { Version: testVersion, }, }, - { - name: "v3 with error filter, fault injection disabled", - resources: []*anypb.Any{v3LisWithFilters(errFilter)}, - wantUpdate: map[string]ListenerUpdate{ - v3LDSTarget: { - RouteConfigName: v3RouteConfigName, - MaxStreamDuration: time.Second, - Raw: v3LisWithFilters(errFilter), - }, - }, - wantMD: UpdateMetadata{ - Status: ServiceStatusACKed, - Version: testVersion, - }, - disableFI: true, - }, { name: "v2 listener resource", resources: []*anypb.Any{v2Lis}, @@ -572,9 +542,6 @@ func (s) TestUnmarshalListener_ClientSide(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !test.disableFI - update, md, err := UnmarshalListener(testVersion, test.resources, nil) if (err != nil) != test.wantErr { t.Fatalf("UnmarshalListener(), got err: %v, wantErr: %v", err, test.wantErr) 
@@ -585,7 +552,6 @@ func (s) TestUnmarshalListener_ClientSide(t *testing.T) { if diff := cmp.Diff(md, test.wantMD, cmpOptsIgnoreDetails); diff != "" { t.Errorf("got unexpected metadata, diff (-got +want): %v", diff) } - env.FaultInjectionSupport = oldFI }) } } @@ -951,36 +917,6 @@ func (s) TestUnmarshalListener_ServerSide(t *testing.T) { wantMD: errMD, wantErr: "failed unmarshaling of network filter", }, - { - name: "client only http filter inside the network filter", - resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ - Name: v3LDSTarget, - Address: localSocketAddress, - FilterChains: []*v3listenerpb.FilterChain{ - { - Name: "filter-chain-1", - Filters: []*v3listenerpb.Filter{ - { - Name: "hcm", - ConfigType: &v3listenerpb.Filter_TypedConfig{ - TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{ - HttpFilters: []*v3httppb.HttpFilter{ - { - Name: "clientOnlyCustomFilter", - ConfigType: &v3httppb.HttpFilter_TypedConfig{TypedConfig: clientOnlyCustomFilterConfig}, - }, - }, - }), - }, - }, - }, - }, - }, - })}, - wantUpdate: map[string]ListenerUpdate{v3LDSTarget: {}}, - wantMD: errMD, - wantErr: "not supported server-side", - }, { name: "unexpected transport socket name", resources: []*anypb.Any{testutils.MarshalAny(&v3listenerpb.Listener{ @@ -1287,10 +1223,6 @@ func (s) TestUnmarshalListener_ServerSide(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = true - defer func() { env.FaultInjectionSupport = oldFI }() - gotUpdate, md, err := UnmarshalListener(testVersion, test.resources, nil) if (err != nil) != (test.wantErr != "") { t.Fatalf("UnmarshalListener(), got err: %v, wantErr: %v", err, test.wantErr) diff --git a/xds/internal/client/load/reporter.go b/xds/internal/xdsclient/load/reporter.go similarity index 100% rename from xds/internal/client/load/reporter.go rename to xds/internal/xdsclient/load/reporter.go diff --git a/xds/internal/client/load/store.go b/xds/internal/xdsclient/load/store.go similarity index 100% rename from xds/internal/client/load/store.go rename to xds/internal/xdsclient/load/store.go diff --git a/xds/internal/client/load/store_test.go b/xds/internal/xdsclient/load/store_test.go similarity index 100% rename from xds/internal/client/load/store_test.go rename to xds/internal/xdsclient/load/store_test.go diff --git a/xds/internal/client/loadreport.go b/xds/internal/xdsclient/loadreport.go similarity index 98% rename from xds/internal/client/loadreport.go rename to xds/internal/xdsclient/loadreport.go index be42a6e0c383..32a71dada7fb 100644 --- a/xds/internal/client/loadreport.go +++ b/xds/internal/xdsclient/loadreport.go @@ -15,13 +15,13 @@ * limitations under the License. */ -package client +package xdsclient import ( "context" "google.golang.org/grpc" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" ) // ReportLoad starts an load reporting stream to the given server. 
If the server diff --git a/xds/internal/client/tests/loadreport_test.go b/xds/internal/xdsclient/loadreport_test.go similarity index 94% rename from xds/internal/client/tests/loadreport_test.go rename to xds/internal/xdsclient/loadreport_test.go index b1ec37294771..0f29e96fc1eb 100644 --- a/xds/internal/client/tests/loadreport_test.go +++ b/xds/internal/xdsclient/loadreport_test.go @@ -18,7 +18,7 @@ * */ -package tests_test +package xdsclient_test import ( "context" @@ -34,13 +34,13 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" - "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" "google.golang.org/protobuf/testing/protocmp" - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 xDS API client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 xDS API client. ) const ( @@ -56,7 +56,7 @@ func (s) TestLRSClient(t *testing.T) { } defer sCleanup() - xdsC, err := client.NewWithConfigForTesting(&bootstrap.Config{ + xdsC, err := xdsclient.NewWithConfigForTesting(&bootstrap.Config{ BalancerName: fs.Address, Creds: grpc.WithTransportCredentials(insecure.NewCredentials()), NodeProto: &v2corepb.Node{}, diff --git a/xds/internal/client/logging.go b/xds/internal/xdsclient/logging.go similarity index 98% rename from xds/internal/client/logging.go rename to xds/internal/xdsclient/logging.go index bff3fb1d3df3..e28ea0d04103 100644 --- a/xds/internal/client/logging.go +++ b/xds/internal/xdsclient/logging.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "fmt" diff --git a/xds/internal/client/rds_test.go b/xds/internal/xdsclient/rds_test.go similarity index 83% rename from xds/internal/client/rds_test.go rename to xds/internal/xdsclient/rds_test.go index 33bcc2ad13f8..15a2afee1f7f 100644 --- a/xds/internal/client/rds_test.go +++ b/xds/internal/xdsclient/rds_test.go @@ -18,7 +18,7 @@ * */ -package client +package xdsclient import ( "fmt" @@ -88,7 +88,6 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { rc *v3routepb.RouteConfiguration wantUpdate RouteConfigUpdate wantError bool - disableFI bool // disable fault injection }{ { name: "default-route-match-field-is-nil", @@ -474,19 +473,10 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { rc: goodRouteConfigWithFilterConfigs(map[string]*anypb.Any{"foo": wrappedOptionalFilter("unknown.custom.filter")}), wantUpdate: goodUpdateWithFilterConfigs(nil), }, - { - name: "good-route-config-with-http-err-filter-config-fi-disabled", - disableFI: true, - rc: goodRouteConfigWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), - wantUpdate: goodUpdateWithFilterConfigs(nil), - }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !test.disableFI - gotUpdate, gotError := generateRDSUpdateFromRouteConfiguration(test.rc, nil, false) if (gotError != nil) != test.wantError || !cmp.Equal(gotUpdate, test.wantUpdate, cmpopts.EquateEmpty(), @@ -494,8 +484,6 @@ func (s) TestRDSGenerateRDSUpdateFromRouteConfiguration(t *testing.T) { return fmt.Sprint(fc) })) { t.Errorf("generateRDSUpdateFromRouteConfiguration(%+v, %v) returned unexpected, diff (-want 
+got):\\n%s", test.rc, ldsTarget, cmp.Diff(test.wantUpdate, gotUpdate, cmpopts.EquateEmpty())) - - env.FaultInjectionSupport = oldFI } }) } @@ -815,7 +803,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes []*v3routepb.Route wantRoutes []*Route wantErr bool - disableFI bool // disable fault injection }{ { name: "no path", @@ -1167,6 +1154,119 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { }}, wantErr: false, }, + { + name: "good-with-channel-id-hash-policy", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + Headers: []*v3routepb.HeaderMatcher{ + { + Name: "th", + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{ + PrefixMatch: "tv", + }, + InvertMatch: true, + }, + }, + RuntimeFraction: &v3corepb.RuntimeFractionalPercent{ + DefaultValue: &v3typepb.FractionalPercent{ + Numerator: 1, + Denominator: v3typepb.FractionalPercent_HUNDRED, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"}}}, + }, + }}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + Headers: []*HeaderMatcher{ + { + Name: "th", + InvertMatch: newBoolP(true), + PrefixMatch: newStringP("tv"), + }, + }, + Fraction: newUInt32P(10000), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + HashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeChannelID}, + }, + }}, + wantErr: false, + }, + // This tests that policy.Regex ends up being nil if RegexRewrite is not + // set in xds response. 
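
As the comment above notes, policy.Regex stays nil when RegexRewrite is absent from the xDS response, and the case below asserts exactly that. Downstream, a nil Regex means the raw header value is hashed; when a rewrite is present, the substitution is applied first. A sketch of that consumption, with the helper name and call site assumed (the actual use is in the ring-hash code, not in this diff):

// Assumed helper showing how a parsed header hash policy would be consumed.
func headerHashKey(p *HashPolicy, headerValue string) string {
	if p.Regex != nil {
		// RegexRewrite was set in the xDS response: rewrite before hashing.
		return p.Regex.ReplaceAllString(headerValue, p.RegexSubstitution)
	}
	// Regex is nil (the case below): hash the raw header value.
	return headerValue
}
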
+ { + name: "good-with-header-hash-policy-no-regex-specified", + routes: []*v3routepb.Route{ + { + Match: &v3routepb.RouteMatch{ + PathSpecifier: &v3routepb.RouteMatch_Prefix{Prefix: "/a/"}, + Headers: []*v3routepb.HeaderMatcher{ + { + Name: "th", + HeaderMatchSpecifier: &v3routepb.HeaderMatcher_PrefixMatch{ + PrefixMatch: "tv", + }, + InvertMatch: true, + }, + }, + RuntimeFraction: &v3corepb.RuntimeFractionalPercent{ + DefaultValue: &v3typepb.FractionalPercent{ + Numerator: 1, + Denominator: v3typepb.FractionalPercent_HUNDRED, + }, + }, + }, + Action: &v3routepb.Route_Route{ + Route: &v3routepb.RouteAction{ + ClusterSpecifier: &v3routepb.RouteAction_WeightedClusters{ + WeightedClusters: &v3routepb.WeightedCluster{ + Clusters: []*v3routepb.WeightedCluster_ClusterWeight{ + {Name: "B", Weight: &wrapperspb.UInt32Value{Value: 60}}, + {Name: "A", Weight: &wrapperspb.UInt32Value{Value: 40}}, + }, + TotalWeight: &wrapperspb.UInt32Value{Value: 100}, + }}, + HashPolicy: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{Header: &v3routepb.RouteAction_HashPolicy_Header{HeaderName: ":path"}}}, + }, + }}, + }, + }, + wantRoutes: []*Route{{ + Prefix: newStringP("/a/"), + Headers: []*HeaderMatcher{ + { + Name: "th", + InvertMatch: newBoolP(true), + PrefixMatch: newStringP("tv"), + }, + }, + Fraction: newUInt32P(10000), + WeightedClusters: map[string]WeightedCluster{"A": {Weight: 40}, "B": {Weight: 60}}, + HashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path"}, + }, + }}, + wantErr: false, + }, { name: "with custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": customFilterConfig}), @@ -1182,12 +1282,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": wrappedOptionalFilter("custom.filter")}), wantRoutes: goodUpdateWithFilterConfigs(map[string]httpfilter.FilterConfig{"foo": filterConfig{Override: customFilterConfig}}), }, - { - name: "with custom HTTP filter config, FI disabled", - disableFI: true, - routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": customFilterConfig}), - wantRoutes: goodUpdateWithFilterConfigs(nil), - }, { name: "with erroring custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), @@ -1198,12 +1292,6 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": wrappedOptionalFilter("err.custom.filter")}), wantErr: true, }, - { - name: "with erroring custom HTTP filter config, FI disabled", - disableFI: true, - routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": errFilterConfig}), - wantRoutes: goodUpdateWithFilterConfigs(nil), - }, { name: "with unknown custom HTTP filter config", routes: goodRouteWithFilterConfigs(map[string]*anypb.Any{"foo": unknownFilterConfig}), @@ -1223,13 +1311,11 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { return fmt.Sprint(fc) }), } - + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - oldFI := env.FaultInjectionSupport - env.FaultInjectionSupport = !tt.disableFI - defer func() { env.FaultInjectionSupport = oldFI }() - got, err := routesProtoToSlice(tt.routes, nil, false) if (err != nil) != tt.wantErr { t.Fatalf("routesProtoToSlice() error = %v, wantErr %v", err, tt.wantErr) @@ -1241,6 
+1327,119 @@ func (s) TestRoutesProtoToSlice(t *testing.T) { } } +func (s) TestHashPoliciesProtoToSlice(t *testing.T) { + tests := []struct { + name string + hashPolicies []*v3routepb.RouteAction_HashPolicy + wantHashPolicies []*HashPolicy + wantErr bool + }{ + // header-hash-policy tests a basic hash policy that specifies to hash a + // certain header. + { + name: "header-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: ":path", + RegexRewrite: &v3matcherpb.RegexMatchAndSubstitute{ + Pattern: &v3matcherpb.RegexMatcher{Regex: "/products"}, + Substitution: "/products", + }, + }, + }, + }, + }, + wantHashPolicies: []*HashPolicy{ + { + HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), + RegexSubstitution: "/products", + }, + }, + }, + // channel-id-hash-policy tests a basic hash policy that specifies to + // hash a unique identifier of the channel. + { + name: "channel-id-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"}}}, + }, + wantHashPolicies: []*HashPolicy{ + {HashPolicyType: HashPolicyTypeChannelID}, + }, + }, + // unsupported-filter-state-key tests that an unsupported key in the + // filter state hash policy are treated as a no-op. + { + name: "wrong-filter-state-key", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "unsupported key"}}}, + }, + }, + // no-op-hash-policy tests that hash policies that are not supported by + // grpc are treated as a no-op. + { + name: "no-op-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + {PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{}}, + }, + }, + // header-and-channel-id-hash-policy test that a list of header and + // channel id hash policies are successfully converted to an internal + // struct. 
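
Taken together, the filter-state cases pin down that io.grpc.channel_id is the only supported key, and that any other key, or an empty FilterState, is silently dropped. A hedged sketch of that branch of the conversion; the real implementation is hashPoliciesProtoToSlice in rds.go, outside this diff:

// Assumed helper isolating the filter-state branch of the conversion.
func hashPolicyFromFilterState(fs *v3routepb.RouteAction_HashPolicy_FilterState) *HashPolicy {
	if fs.GetKey() != "io.grpc.channel_id" {
		// Unsupported keys are a no-op, per the wrong-filter-state-key case.
		return nil
	}
	return &HashPolicy{HashPolicyType: HashPolicyTypeChannelID}
}
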
+ { + name: "header-and-channel-id-hash-policy", + hashPolicies: []*v3routepb.RouteAction_HashPolicy{ + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{ + Header: &v3routepb.RouteAction_HashPolicy_Header{ + HeaderName: ":path", + RegexRewrite: &v3matcherpb.RegexMatchAndSubstitute{ + Pattern: &v3matcherpb.RegexMatcher{Regex: "/products"}, + Substitution: "/products", + }, + }, + }, + }, + { + PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"}}, + Terminal: true, + }, + }, + wantHashPolicies: []*HashPolicy{ + { + HashPolicyType: HashPolicyTypeHeader, + HeaderName: ":path", + Regex: func() *regexp.Regexp { return regexp.MustCompile("/products") }(), + RegexSubstitution: "/products", + }, + { + HashPolicyType: HashPolicyTypeChannelID, + Terminal: true, + }, + }, + }, + } + + oldRingHashSupport := env.RingHashSupport + env.RingHashSupport = true + defer func() { env.RingHashSupport = oldRingHashSupport }() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := hashPoliciesProtoToSlice(tt.hashPolicies, nil) + if (err != nil) != tt.wantErr { + t.Fatalf("hashPoliciesProtoToSlice() error = %v, wantErr %v", err, tt.wantErr) + } + if diff := cmp.Diff(got, tt.wantHashPolicies, cmp.AllowUnexported(regexp.Regexp{})); diff != "" { + t.Fatalf("hashPoliciesProtoToSlice() returned unexpected diff (-got +want):\n%s", diff) + } + }) + } +} + func newStringP(s string) *string { return &s } diff --git a/xds/internal/client/requests_counter.go b/xds/internal/xdsclient/requests_counter.go similarity index 54% rename from xds/internal/client/requests_counter.go rename to xds/internal/xdsclient/requests_counter.go index f033e1920991..beed2e9d0add 100644 --- a/xds/internal/client/requests_counter.go +++ b/xds/internal/xdsclient/requests_counter.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "fmt" @@ -24,44 +24,53 @@ import ( "sync/atomic" ) -type servicesRequestsCounter struct { +type clusterNameAndServiceName struct { + clusterName, edsServcieName string +} + +type clusterRequestsCounter struct { mu sync.Mutex - services map[string]*ServiceRequestsCounter + clusters map[clusterNameAndServiceName]*ClusterRequestsCounter } -var src = &servicesRequestsCounter{ - services: make(map[string]*ServiceRequestsCounter), +var src = &clusterRequestsCounter{ + clusters: make(map[clusterNameAndServiceName]*ClusterRequestsCounter), } -// ServiceRequestsCounter is used to track the total inflight requests for a +// ClusterRequestsCounter is used to track the total inflight requests for a // service with the provided name. -type ServiceRequestsCounter struct { - ServiceName string - numRequests uint32 +type ClusterRequestsCounter struct { + ClusterName string + EDSServiceName string + numRequests uint32 } -// GetServiceRequestsCounter returns the ServiceRequestsCounter with the +// GetClusterRequestsCounter returns the ClusterRequestsCounter with the // provided serviceName. If one does not exist, it creates it. 
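
The counter is now keyed by the (cluster name, EDS service name) pair rather than a single service string, matching how circuit breakers are scoped. The data-path flow the API below supports, using only names from this file (the max value would come from the cluster's circuit-breaker configuration):

// Sketch of the circuit-breaking call flow.
func startRPC(maxRequests uint32) (done func(), err error) {
	counter := GetClusterRequestsCounter("cluster-a", "eds-service-a")
	if err := counter.StartRequest(maxRequests); err != nil {
		return nil, err // over the limit: fail the RPC locally
	}
	return counter.EndRequest, nil // caller invokes done() when the RPC ends
}
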
-func GetServiceRequestsCounter(serviceName string) *ServiceRequestsCounter { +func GetClusterRequestsCounter(clusterName, edsServiceName string) *ClusterRequestsCounter { src.mu.Lock() defer src.mu.Unlock() - c, ok := src.services[serviceName] + k := clusterNameAndServiceName{ + clusterName: clusterName, + edsServcieName: edsServiceName, + } + c, ok := src.clusters[k] if !ok { - c = &ServiceRequestsCounter{ServiceName: serviceName} - src.services[serviceName] = c + c = &ClusterRequestsCounter{ClusterName: clusterName} + src.clusters[k] = c } return c } -// StartRequest starts a request for a service, incrementing its number of +// StartRequest starts a request for a cluster, incrementing its number of // requests by 1. Returns an error if the max number of requests is exceeded. -func (c *ServiceRequestsCounter) StartRequest(max uint32) error { +func (c *ClusterRequestsCounter) StartRequest(max uint32) error { // Note that during race, the limits could be exceeded. This is allowed: // "Since the implementation is eventually consistent, races between threads // may allow limits to be potentially exceeded." // https://www.envoyproxy.io/docs/envoy/latest/intro/arch_overview/upstream/circuit_breaking#arch-overview-circuit-break. if atomic.LoadUint32(&c.numRequests) >= max { - return fmt.Errorf("max requests %v exceeded on service %v", max, c.ServiceName) + return fmt.Errorf("max requests %v exceeded on service %v", max, c.ClusterName) } atomic.AddUint32(&c.numRequests, 1) return nil @@ -69,18 +78,30 @@ func (c *ServiceRequestsCounter) StartRequest(max uint32) error { // EndRequest ends a request for a service, decrementing its number of requests // by 1. -func (c *ServiceRequestsCounter) EndRequest() { +func (c *ClusterRequestsCounter) EndRequest() { atomic.AddUint32(&c.numRequests, ^uint32(0)) } // ClearCounterForTesting clears the counter for the service. Should be only // used in tests. -func ClearCounterForTesting(serviceName string) { +func ClearCounterForTesting(clusterName, edsServiceName string) { src.mu.Lock() defer src.mu.Unlock() - c, ok := src.services[serviceName] + k := clusterNameAndServiceName{ + clusterName: clusterName, + edsServcieName: edsServiceName, + } + c, ok := src.clusters[k] if !ok { return } c.numRequests = 0 } + +// ClearAllCountersForTesting clears all the counters. Should be only used in +// tests. 
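
One detail in EndRequest above: there is no atomic subtract for uint32, so the decrement is written as atomic.AddUint32(&c.numRequests, ^uint32(0)), relying on unsigned wraparound; this is the idiom documented in the sync/atomic package. A self-contained check:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	var n uint32 = 5
	// Adding the all-ones value wraps around and subtracts exactly one.
	atomic.AddUint32(&n, ^uint32(0))
	fmt.Println(n) // prints 4
}
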
+func ClearAllCountersForTesting() { + src.mu.Lock() + defer src.mu.Unlock() + src.clusters = make(map[clusterNameAndServiceName]*ClusterRequestsCounter) +} diff --git a/xds/internal/client/requests_counter_test.go b/xds/internal/xdsclient/requests_counter_test.go similarity index 81% rename from xds/internal/client/requests_counter_test.go rename to xds/internal/xdsclient/requests_counter_test.go index 30892fc747a0..cd95aeaf82e0 100644 --- a/xds/internal/client/requests_counter_test.go +++ b/xds/internal/xdsclient/requests_counter_test.go @@ -18,7 +18,7 @@ * */ -package client +package xdsclient import ( "sync" @@ -26,6 +26,8 @@ import ( "testing" ) +const testService = "test-service-name" + type counterTest struct { name string maxRequests uint32 @@ -51,9 +53,9 @@ var tests = []counterTest{ }, } -func resetServiceRequestsCounter() { - src = &servicesRequestsCounter{ - services: make(map[string]*ServiceRequestsCounter), +func resetClusterRequestsCounter() { + src = &clusterRequestsCounter{ + clusters: make(map[clusterNameAndServiceName]*ClusterRequestsCounter), } } @@ -67,7 +69,7 @@ func testCounter(t *testing.T, test counterTest) { var successes, errors uint32 for i := 0; i < int(test.numRequests); i++ { go func() { - counter := GetServiceRequestsCounter(test.name) + counter := GetClusterRequestsCounter(test.name, testService) defer requestsDone.Done() err := counter.StartRequest(test.maxRequests) if err == nil { @@ -103,7 +105,7 @@ func testCounter(t *testing.T, test counterTest) { } func (s) TestRequestsCounter(t *testing.T) { - defer resetServiceRequestsCounter() + defer resetClusterRequestsCounter() for _, test := range tests { t.Run(test.name, func(t *testing.T) { testCounter(t, test) @@ -111,18 +113,18 @@ func (s) TestRequestsCounter(t *testing.T) { } } -func (s) TestGetServiceRequestsCounter(t *testing.T) { - defer resetServiceRequestsCounter() +func (s) TestGetClusterRequestsCounter(t *testing.T) { + defer resetClusterRequestsCounter() for _, test := range tests { - counterA := GetServiceRequestsCounter(test.name) - counterB := GetServiceRequestsCounter(test.name) + counterA := GetClusterRequestsCounter(test.name, testService) + counterB := GetClusterRequestsCounter(test.name, testService) if counterA != counterB { t.Errorf("counter %v %v != counter %v %v", counterA, *counterA, counterB, *counterB) } } } -func startRequests(t *testing.T, n uint32, max uint32, counter *ServiceRequestsCounter) { +func startRequests(t *testing.T, n uint32, max uint32, counter *ClusterRequestsCounter) { for i := uint32(0); i < n; i++ { if err := counter.StartRequest(max); err != nil { t.Fatalf("error starting initial request: %v", err) @@ -131,11 +133,11 @@ func startRequests(t *testing.T, n uint32, max uint32, counter *ServiceRequestsC } func (s) TestSetMaxRequestsIncreased(t *testing.T) { - defer resetServiceRequestsCounter() - const serviceName string = "set-max-requests-increased" + defer resetClusterRequestsCounter() + const clusterName string = "set-max-requests-increased" var initialMax uint32 = 16 - counter := GetServiceRequestsCounter(serviceName) + counter := GetClusterRequestsCounter(clusterName, testService) startRequests(t, initialMax, initialMax, counter) if err := counter.StartRequest(initialMax); err == nil { t.Fatal("unexpected success on start request after max met") @@ -148,11 +150,11 @@ func (s) TestSetMaxRequestsIncreased(t *testing.T) { } func (s) TestSetMaxRequestsDecreased(t *testing.T) { - defer resetServiceRequestsCounter() - const serviceName string = 
"set-max-requests-decreased" + defer resetClusterRequestsCounter() + const clusterName string = "set-max-requests-decreased" var initialMax uint32 = 16 - counter := GetServiceRequestsCounter(serviceName) + counter := GetClusterRequestsCounter(clusterName, testService) startRequests(t, initialMax-1, initialMax, counter) newMax := initialMax - 1 diff --git a/xds/internal/client/singleton.go b/xds/internal/xdsclient/singleton.go similarity index 63% rename from xds/internal/client/singleton.go rename to xds/internal/xdsclient/singleton.go index 99f195341acd..f045790e2a40 100644 --- a/xds/internal/client/singleton.go +++ b/xds/internal/xdsclient/singleton.go @@ -16,32 +16,30 @@ * */ -package client +package xdsclient import ( + "bytes" + "encoding/json" "fmt" "sync" "time" - "google.golang.org/grpc/xds/internal/client/bootstrap" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const defaultWatchExpiryTimeout = 15 * time.Second // This is the Client returned by New(). It contains one client implementation, // and maintains the refcount. -var singletonClient = &Client{} +var singletonClient = &clientRefCounted{} // To override in tests. var bootstrapNewConfig = bootstrap.NewConfig -// Client is a full fledged gRPC client which queries a set of discovery APIs -// (collectively termed as xDS) on a remote management server, to discover -// various dynamic resources. -// -// The xds client is a singleton. It will be shared by the xds resolver and +// clientRefCounted is ref-counted, and to be shared by the xds resolver and // balancer implementations, across multiple ClientConns and Servers. -type Client struct { +type clientRefCounted struct { *clientImpl // This mu protects all the fields, including the embedded clientImpl above. @@ -58,7 +56,18 @@ type Client struct { // Note that the first invocation of New() or NewWithConfig() sets the client // singleton. The following calls will return the singleton xds client without // checking or using the config. -func New() (*Client, error) { +func New() (XDSClient, error) { + // This cannot just return newRefCounted(), because in error cases, the + // returned nil is a typed nil (*clientRefCounted), which may cause nil + // checks fail. + c, err := newRefCounted() + if err != nil { + return nil, err + } + return c, nil +} + +func newRefCounted() (*clientRefCounted, error) { singletonClient.mu.Lock() defer singletonClient.mu.Unlock() // If the client implementation was created, increment ref count and return @@ -92,9 +101,9 @@ func New() (*Client, error) { // singleton. The following calls will return the singleton xds client without // checking or using the config. // -// This function is internal only, for c2p resolver to use. DO NOT use this -// elsewhere. Use New() instead. -func NewWithConfig(config *bootstrap.Config) (*Client, error) { +// This function is internal only, for c2p resolver and testing to use. DO NOT +// use this elsewhere. Use New() instead. +func NewWithConfig(config *bootstrap.Config) (XDSClient, error) { singletonClient.mu.Lock() defer singletonClient.mu.Unlock() // If the client implementation was created, increment ref count and return @@ -118,7 +127,7 @@ func NewWithConfig(config *bootstrap.Config) (*Client, error) { // Close closes the client. It does ref count of the xds client implementation, // and closes the gRPC connection to the management server when ref count // reaches 0. 
-func (c *Client) Close() { +func (c *clientRefCounted) Close() { c.mu.Lock() defer c.mu.Unlock() c.refCount-- @@ -134,10 +143,56 @@ func (c *Client) Close() { // // Note that this function doesn't set the singleton, so that the testing states // don't leak. -func NewWithConfigForTesting(config *bootstrap.Config, watchExpiryTimeout time.Duration) (*Client, error) { +func NewWithConfigForTesting(config *bootstrap.Config, watchExpiryTimeout time.Duration) (XDSClient, error) { cl, err := newWithConfig(config, watchExpiryTimeout) if err != nil { return nil, err } - return &Client{clientImpl: cl, refCount: 1}, nil + return &clientRefCounted{clientImpl: cl, refCount: 1}, nil } + +// NewClientWithBootstrapContents returns an xds client for this config, +// separate from the global singleton. This should be used for testing +// purposes only. +func NewClientWithBootstrapContents(contents []byte) (XDSClient, error) { + // Normalize the contents + buf := bytes.Buffer{} + err := json.Indent(&buf, contents, "", "") + if err != nil { + return nil, fmt.Errorf("xds: error normalizing JSON: %v", err) + } + contents = bytes.TrimSpace(buf.Bytes()) + + clientsMu.Lock() + defer clientsMu.Unlock() + if c := clients[string(contents)]; c != nil { + c.mu.Lock() + // Since we don't remove the *Client from the map when it is closed, we + // need to recreate the impl if the ref count dropped to zero. + if c.refCount > 0 { + c.refCount++ + c.mu.Unlock() + return c, nil + } + c.mu.Unlock() + } + + bcfg, err := bootstrap.NewConfigFromContents(contents) + if err != nil { + return nil, fmt.Errorf("xds: error with bootstrap config: %v", err) + } + + cImpl, err := newWithConfig(bcfg, defaultWatchExpiryTimeout) + if err != nil { + return nil, err + } + + c := &clientRefCounted{clientImpl: cImpl, refCount: 1} + clients[string(contents)] = c + return c, nil +} + +var ( + clients = map[string]*clientRefCounted{} + clientsMu sync.Mutex +) diff --git a/xds/internal/client/transport_helper.go b/xds/internal/xdsclient/transport_helper.go similarity index 99% rename from xds/internal/client/transport_helper.go rename to xds/internal/xdsclient/transport_helper.go index 671e5b3220f1..1e3caa606d1d 100644 --- a/xds/internal/client/transport_helper.go +++ b/xds/internal/xdsclient/transport_helper.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "context" @@ -24,7 +24,7 @@ import ( "time" "github.com/golang/protobuf/proto" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" "google.golang.org/grpc" "google.golang.org/grpc/internal/buffer" diff --git a/xds/internal/client/v2/ack_test.go b/xds/internal/xdsclient/v2/ack_test.go similarity index 99% rename from xds/internal/client/v2/ack_test.go rename to xds/internal/xdsclient/v2/ack_test.go index 53c8cef189d5..1c41cd9d1ad8 100644 --- a/xds/internal/client/v2/ack_test.go +++ b/xds/internal/xdsclient/v2/ack_test.go @@ -33,9 +33,9 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( diff --git a/xds/internal/client/v2/cds_test.go b/xds/internal/xdsclient/v2/cds_test.go similarity index 99% rename from xds/internal/client/v2/cds_test.go rename to xds/internal/xdsclient/v2/cds_test.go index ba7f627b25ff..da4738460930 100644 --- 
a/xds/internal/client/v2/cds_test.go +++ b/xds/internal/xdsclient/v2/cds_test.go @@ -28,8 +28,8 @@ import ( corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" anypb "github.com/golang/protobuf/ptypes/any" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) const ( diff --git a/xds/internal/client/v2/client.go b/xds/internal/xdsclient/v2/client.go similarity index 98% rename from xds/internal/client/v2/client.go rename to xds/internal/xdsclient/v2/client.go index b91482d13cf9..766db2564b8d 100644 --- a/xds/internal/client/v2/client.go +++ b/xds/internal/xdsclient/v2/client.go @@ -28,8 +28,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/pretty" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" @@ -140,7 +140,7 @@ func (v2c *client) RecvResponse(s grpc.ClientStream) (proto.Message, error) { resp, err := stream.Recv() if err != nil { - // TODO: call watch callbacks with error when stream is broken. + v2c.parent.NewConnectionError(err) return nil, fmt.Errorf("xds: stream.Recv() failed: %v", err) } v2c.logger.Infof("ADS response received, type: %v", resp.GetTypeUrl()) diff --git a/xds/internal/client/v2/client_test.go b/xds/internal/xdsclient/v2/client_test.go similarity index 99% rename from xds/internal/client/v2/client_test.go rename to xds/internal/xdsclient/v2/client_test.go index 371375f3ee5c..138c9a161695 100644 --- a/xds/internal/client/v2/client_test.go +++ b/xds/internal/xdsclient/v2/client_test.go @@ -37,9 +37,9 @@ import ( "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/resolver" "google.golang.org/grpc/resolver/manual" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" "google.golang.org/protobuf/testing/protocmp" xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" @@ -339,6 +339,8 @@ func (t *testUpdateReceiver) NewEndpoints(d map[string]xdsclient.EndpointsUpdate t.newUpdate(xdsclient.EndpointsResource, dd, metadata) } +func (t *testUpdateReceiver) NewConnectionError(error) {} + func (t *testUpdateReceiver) newUpdate(rType xdsclient.ResourceType, d map[string]interface{}, metadata xdsclient.UpdateMetadata) { t.f(rType, d, metadata) } diff --git a/xds/internal/client/v2/eds_test.go b/xds/internal/xdsclient/v2/eds_test.go similarity index 99% rename from xds/internal/client/v2/eds_test.go rename to xds/internal/xdsclient/v2/eds_test.go index 08e75d373017..7338ec80a5e5 100644 --- a/xds/internal/client/v2/eds_test.go +++ b/xds/internal/xdsclient/v2/eds_test.go @@ -28,9 +28,9 @@ import ( anypb "github.com/golang/protobuf/ptypes/any" "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/xds/internal" - xdsclient "google.golang.org/grpc/xds/internal/client" xtestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" ) var ( diff --git a/xds/internal/client/v2/lds_test.go b/xds/internal/xdsclient/v2/lds_test.go similarity index 99% rename 
from xds/internal/client/v2/lds_test.go rename to xds/internal/xdsclient/v2/lds_test.go index 22fa35d5e51e..e6f1d8f1cf02 100644 --- a/xds/internal/client/v2/lds_test.go +++ b/xds/internal/xdsclient/v2/lds_test.go @@ -26,7 +26,7 @@ import ( v2xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" - xdsclient "google.golang.org/grpc/xds/internal/client" + "google.golang.org/grpc/xds/internal/xdsclient" ) // TestLDSHandleResponse starts a fake xDS server, makes a ClientConn to it, diff --git a/xds/internal/client/v2/loadreport.go b/xds/internal/xdsclient/v2/loadreport.go similarity index 98% rename from xds/internal/client/v2/loadreport.go rename to xds/internal/xdsclient/v2/loadreport.go index 17ea8c8d4c43..77db5eb9d8d6 100644 --- a/xds/internal/client/v2/loadreport.go +++ b/xds/internal/xdsclient/v2/loadreport.go @@ -27,7 +27,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "google.golang.org/grpc/internal/pretty" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" v2corepb "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" v2endpointpb "github.com/envoyproxy/go-control-plane/envoy/api/v2/endpoint" diff --git a/xds/internal/client/v2/rds_test.go b/xds/internal/xdsclient/v2/rds_test.go similarity index 99% rename from xds/internal/client/v2/rds_test.go rename to xds/internal/xdsclient/v2/rds_test.go index 12495428bf95..745308c4e5b6 100644 --- a/xds/internal/client/v2/rds_test.go +++ b/xds/internal/xdsclient/v2/rds_test.go @@ -27,8 +27,8 @@ import ( xdspb "github.com/envoyproxy/go-control-plane/envoy/api/v2" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/testutils/fakeserver" + "google.golang.org/grpc/xds/internal/xdsclient" ) // doLDS makes a LDS watch, and waits for the response and ack to finish. diff --git a/xds/internal/client/v3/client.go b/xds/internal/xdsclient/v3/client.go similarity index 98% rename from xds/internal/client/v3/client.go rename to xds/internal/xdsclient/v3/client.go index 200da2ac7d73..6088189f97f1 100644 --- a/xds/internal/client/v3/client.go +++ b/xds/internal/xdsclient/v3/client.go @@ -29,8 +29,8 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/pretty" - xdsclient "google.golang.org/grpc/xds/internal/client" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3adsgrpc "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" @@ -140,7 +140,7 @@ func (v3c *client) RecvResponse(s grpc.ClientStream) (proto.Message, error) { resp, err := stream.Recv() if err != nil { - // TODO: call watch callbacks with error when stream is broken. 
+ v3c.parent.NewConnectionError(err) return nil, fmt.Errorf("xds: stream.Recv() failed: %v", err) } v3c.logger.Infof("ADS response received, type: %v", resp.GetTypeUrl()) diff --git a/xds/internal/client/v3/loadreport.go b/xds/internal/xdsclient/v3/loadreport.go similarity index 98% rename from xds/internal/client/v3/loadreport.go rename to xds/internal/xdsclient/v3/loadreport.go index 1de0ccf57503..147751baab03 100644 --- a/xds/internal/client/v3/loadreport.go +++ b/xds/internal/xdsclient/v3/loadreport.go @@ -27,7 +27,7 @@ import ( "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "google.golang.org/grpc/internal/pretty" - "google.golang.org/grpc/xds/internal/client/load" + "google.golang.org/grpc/xds/internal/xdsclient/load" v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3endpointpb "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" diff --git a/xds/internal/client/watchers.go b/xds/internal/xdsclient/watchers.go similarity index 97% rename from xds/internal/client/watchers.go rename to xds/internal/xdsclient/watchers.go index 446f5cca9973..e26ed360308a 100644 --- a/xds/internal/client/watchers.go +++ b/xds/internal/xdsclient/watchers.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "fmt" @@ -66,6 +66,17 @@ func (wi *watchInfo) newUpdate(update interface{}) { wi.c.scheduleCallback(wi, update, nil) } +func (wi *watchInfo) newError(err error) { + wi.mu.Lock() + defer wi.mu.Unlock() + if wi.state == watchInfoStateCanceled { + return + } + wi.state = watchInfoStateRespReceived + wi.expiryTimer.Stop() + wi.sendErrorLocked(err) +} + func (wi *watchInfo) resourceNotFound() { wi.mu.Lock() defer wi.mu.Unlock() diff --git a/xds/internal/client/watchers_cluster_test.go b/xds/internal/xdsclient/watchers_cluster_test.go similarity index 89% rename from xds/internal/client/watchers_cluster_test.go rename to xds/internal/xdsclient/watchers_cluster_test.go index c9837cd51978..0ded36445994 100644 --- a/xds/internal/client/watchers_cluster_test.go +++ b/xds/internal/xdsclient/watchers_cluster_test.go @@ -18,10 +18,11 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -66,7 +67,7 @@ func (s) TestClusterWatch(t *testing.T) { wantUpdate := ClusterUpdate{ClusterName: testEDSName} client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -75,7 +76,7 @@ func (s) TestClusterWatch(t *testing.T) { testCDSName: wantUpdate, "randomName": {}, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -131,7 +132,7 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) { wantUpdate := ClusterUpdate{ClusterName: testEDSName} client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } @@ -140,7 +141,7 @@ func (s) TestClusterTwoWatchSameResourceName(t *testing.T) { cancelLastWatch() client.NewClusters(map[string]ClusterUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) for i := 
0; i < count-1; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } @@ -208,11 +209,11 @@ func (s) TestClusterThreeWatchDifferentResourceName(t *testing.T) { }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate1); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -249,7 +250,7 @@ func (s) TestClusterWatchAfterCache(t *testing.T) { client.NewClusters(map[string]ClusterUpdate{ testCDSName: wantUpdate, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -265,7 +266,7 @@ } // New watch should receive the update. - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -349,7 +350,7 @@ func (s) TestClusterWatchExpiryTimerStop(t *testing.T) { client.NewClusters(map[string]ClusterUpdate{ testCDSName: wantUpdate, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -408,10 +409,10 @@ func (s) TestClusterResourceRemoved(t *testing.T) { testCDSName + "1": wantUpdate1, testCDSName + "2": wantUpdate2, }, UpdateMetadata{}) - if err := verifyClusterUpdate(ctx, clusterUpdateCh1, wantUpdate1); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh1, wantUpdate1, nil); err != nil { t.Fatal(err) } - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } @@ -424,7 +425,7 @@ } // Watcher 2 should get the same update again. - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } @@ -439,7 +440,43 @@ } // Watcher 2 should get the same update again. - if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2); err != nil { + if err := verifyClusterUpdate(ctx, clusterUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } +} + +// TestClusterWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. 
+func (s) TestClusterWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + clusterUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchCluster(testCDSName, func(update ClusterUpdate, err error) { + clusterUpdateCh.Send(clusterUpdateErr{u: update, err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[ClusterResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewClusters(map[string]ClusterUpdate{testCDSName: {}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyClusterUpdate(ctx, clusterUpdateCh, ClusterUpdate{}, wantError); err != nil { t.Fatal(err) } } diff --git a/xds/internal/client/watchers_endpoints_test.go b/xds/internal/xdsclient/watchers_endpoints_test.go similarity index 88% rename from xds/internal/client/watchers_endpoints_test.go rename to xds/internal/xdsclient/watchers_endpoints_test.go index bff4544d2679..9b5000b8bb9f 100644 --- a/xds/internal/client/watchers_endpoints_test.go +++ b/xds/internal/xdsclient/watchers_endpoints_test.go @@ -18,10 +18,11 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -84,7 +85,7 @@ func (s) TestEndpointsWatch(t *testing.T) { wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -150,7 +151,7 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) { wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } @@ -159,7 +160,7 @@ func (s) TestEndpointsTwoWatchSameResourceName(t *testing.T) { cancelLastWatch() client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) for i := 0; i < count-1; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } @@ -227,11 +228,11 @@ func (s) TestEndpointsThreeWatchDifferentResourceName(t *testing.T) { }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate1); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate2); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate2, nil); err != nil { 
t.Fatal(err) } } @@ -266,7 +267,7 @@ func (s) TestEndpointsWatchAfterCache(t *testing.T) { wantUpdate := EndpointsUpdate{Localities: []Locality{testLocalities[0]}} client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -282,7 +283,7 @@ } // New watch should receive the update. - if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate); err != nil { + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh2, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -332,3 +333,39 @@ func (s) TestEndpointsWatchExpiryTimer(t *testing.T) { t.Fatalf("unexpected endpointsUpdate: (%v, %v), want: (EndpointsUpdate{}, nil)", gotUpdate.u, gotUpdate.err) } } + +// TestEndpointsWatchNACKError covers the case that an update is NACK'ed, and +// the watcher should also receive the error. +func (s) TestEndpointsWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + endpointsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchEndpoints(testCDSName, func(update EndpointsUpdate, err error) { + endpointsUpdateCh.Send(endpointsUpdateErr{u: update, err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[EndpointsResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewEndpoints(map[string]EndpointsUpdate{testCDSName: {}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyEndpointsUpdate(ctx, endpointsUpdateCh, EndpointsUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} diff --git a/xds/internal/client/watchers_listener_test.go b/xds/internal/xdsclient/watchers_listener_test.go similarity index 84% rename from xds/internal/client/watchers_listener_test.go rename to xds/internal/xdsclient/watchers_listener_test.go index fdd4ebd163fa..f7a1169bee39 100644 --- a/xds/internal/client/watchers_listener_test.go +++ b/xds/internal/xdsclient/watchers_listener_test.go @@ -18,10 +18,11 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "google.golang.org/grpc/internal/testutils" @@ -64,7 +65,7 @@ func (s) TestLDSWatch(t *testing.T) { wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -73,7 +74,7 @@ testLDSName: wantUpdate, "randomName": {}, }, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -132,7 +133,7 @@ func (s) 
TestLDSTwoWatchSameResourceName(t *testing.T) { wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } @@ -141,7 +142,7 @@ func (s) TestLDSTwoWatchSameResourceName(t *testing.T) { cancelLastWatch() client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) for i := 0; i < count-1; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } @@ -210,11 +211,11 @@ func (s) TestLDSThreeWatchDifferentResourceName(t *testing.T) { }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate1); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -249,7 +250,7 @@ func (s) TestLDSWatchAfterCache(t *testing.T) { wantUpdate := ListenerUpdate{RouteConfigName: testRDSName} client.NewListeners(map[string]ListenerUpdate{testLDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -265,7 +266,7 @@ func (s) TestLDSWatchAfterCache(t *testing.T) { } // New watch should receive the update. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -323,10 +324,10 @@ func (s) TestLDSResourceRemoved(t *testing.T) { testLDSName + "1": wantUpdate1, testLDSName + "2": wantUpdate2, }, UpdateMetadata{}) - if err := verifyListenerUpdate(ctx, ldsUpdateCh1, wantUpdate1); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh1, wantUpdate1, nil); err != nil { t.Fatal(err) } - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } @@ -339,7 +340,7 @@ func (s) TestLDSResourceRemoved(t *testing.T) { } // Watcher 2 should get the same update again. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } @@ -354,7 +355,43 @@ func (s) TestLDSResourceRemoved(t *testing.T) { } // Watcher 2 should get the same update again. - if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2); err != nil { + if err := verifyListenerUpdate(ctx, ldsUpdateCh2, wantUpdate2, nil); err != nil { + t.Fatal(err) + } +} + +// TestListenerWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. 
+func (s) TestListenerWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + ldsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchListener(testLDSName, func(update ListenerUpdate, err error) { + ldsUpdateCh.Send(ldsUpdateErr{u: update, err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[ListenerResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewListeners(map[string]ListenerUpdate{testLDSName: {}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyListenerUpdate(ctx, ldsUpdateCh, ListenerUpdate{}, wantError); err != nil { t.Fatal(err) } } diff --git a/xds/internal/client/watchers_route_test.go b/xds/internal/xdsclient/watchers_route_test.go similarity index 86% rename from xds/internal/client/watchers_route_test.go rename to xds/internal/xdsclient/watchers_route_test.go index 41640b85b574..5e9b51da9d22 100644 --- a/xds/internal/client/watchers_route_test.go +++ b/xds/internal/xdsclient/watchers_route_test.go @@ -18,10 +18,11 @@ * */ -package client +package xdsclient import ( "context" + "fmt" "testing" "github.com/google/go-cmp/cmp" @@ -73,7 +74,7 @@ func (s) TestRDSWatch(t *testing.T) { }, } client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -146,7 +147,7 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) { } client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } @@ -155,7 +156,7 @@ func (s) TestRDSTwoWatchSameResourceName(t *testing.T) { cancelLastWatch() client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) for i := 0; i < count-1; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate, nil); err != nil { t.Fatal(err) } } @@ -237,11 +238,11 @@ func (s) TestRDSThreeWatchDifferentResourceName(t *testing.T) { }, UpdateMetadata{}) for i := 0; i < count; i++ { - if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate1); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateChs[i], wantUpdate1, nil); err != nil { t.Fatal(err) } } - if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh2, wantUpdate2); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh2, wantUpdate2, nil); err != nil { t.Fatal(err) } } @@ -283,7 +284,7 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { }, } client.NewRouteConfigs(map[string]RouteConfigUpdate{testRDSName: wantUpdate}, UpdateMetadata{}) - if err := 
verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate); err != nil { + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, wantUpdate, nil); err != nil { t.Fatal(err) } @@ -310,3 +311,39 @@ func (s) TestRDSWatchAfterCache(t *testing.T) { t.Errorf("unexpected RouteConfigUpdate: %v, %v, want channel recv timeout", u, err) } } + +// TestRouteWatchNACKError covers the case that an update is NACK'ed, and the +// watcher should also receive the error. +func (s) TestRouteWatchNACKError(t *testing.T) { + apiClientCh, cleanup := overrideNewAPIClient() + defer cleanup() + + client, err := newWithConfig(clientOpts(testXDSServer, false)) + if err != nil { + t.Fatalf("failed to create client: %v", err) + } + defer client.Close() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + c, err := apiClientCh.Receive(ctx) + if err != nil { + t.Fatalf("timeout when waiting for API client to be created: %v", err) + } + apiClient := c.(*testAPIClient) + + rdsUpdateCh := testutils.NewChannel() + cancelWatch := client.WatchRouteConfig(testCDSName, func(update RouteConfigUpdate, err error) { + rdsUpdateCh.Send(rdsUpdateErr{u: update, err: err}) + }) + defer cancelWatch() + if _, err := apiClient.addWatches[RouteConfigResource].Receive(ctx); err != nil { + t.Fatalf("want new watch to start, got error %v", err) + } + + wantError := fmt.Errorf("testing error") + client.NewRouteConfigs(map[string]RouteConfigUpdate{testCDSName: {}}, UpdateMetadata{ErrState: &UpdateErrorMetadata{Err: wantError}}) + if err := verifyRouteConfigUpdate(ctx, rdsUpdateCh, RouteConfigUpdate{}, wantError); err != nil { + t.Fatal(err) + } +} diff --git a/xds/internal/client/xds.go b/xds/internal/xdsclient/xds.go similarity index 89% rename from xds/internal/client/xds.go rename to xds/internal/xdsclient/xds.go index b95224113237..79c2efcd2cbc 100644 --- a/xds/internal/client/xds.go +++ b/xds/internal/xdsclient/xds.go @@ -16,7 +16,7 @@ * */ -package client +package xdsclient import ( "errors" @@ -178,7 +178,7 @@ func validateHTTPFilterConfig(cfg *anypb.Any, lds, optional bool) (httpfilter.Fi } func processHTTPFilterOverrides(cfgs map[string]*anypb.Any) (map[string]httpfilter.FilterConfig, error) { - if !env.FaultInjectionSupport || len(cfgs) == 0 { + if len(cfgs) == 0 { return nil, nil } m := make(map[string]httpfilter.FilterConfig) @@ -207,10 +207,6 @@ func processHTTPFilterOverrides(cfgs map[string]*anypb.Any) (map[string]httpfilt } func processHTTPFilters(filters []*v3httppb.HttpFilter, server bool) ([]HTTPFilter, error) { - if !env.FaultInjectionSupport { - return nil, nil - } - ret := make([]HTTPFilter, 0, len(filters)) seenNames := make(map[string]bool, len(filters)) for _, filter := range filters { @@ -272,13 +268,6 @@ func processServerSideListener(lis *v3listenerpb.Listener) (*ListenerUpdate, err Port: strconv.Itoa(int(sockAddr.GetPortValue())), }, } - chains := lis.GetFilterChains() - if def := lis.GetDefaultFilterChain(); def != nil { - chains = append(chains, def) - } - if err := validateNetworkFilterChains(chains); err != nil { - return nil, err - } fcMgr, err := NewFilterChainManager(lis) if err != nil { @@ -288,59 +277,63 @@ func processServerSideListener(lis *v3listenerpb.Listener) (*ListenerUpdate, err return lu, nil } -func validateNetworkFilterChains(filterChains []*v3listenerpb.FilterChain) error { - for _, filterChain := range filterChains { - seenNames := make(map[string]bool, len(filterChain.GetFilters())) - seenHCM := false - for _, filter := range filterChain.GetFilters() { - 
name := filter.GetName() - if name == "" { - return fmt.Errorf("filter chain {%+v} is missing name field in filter: {%+v}", filterChain, filter) +func processNetworkFilters(filters []*v3listenerpb.Filter) ([]HTTPFilter, error) { + seenNames := make(map[string]bool, len(filters)) + seenHCM := false + var httpFilters []HTTPFilter + for _, filter := range filters { + name := filter.GetName() + if name == "" { + return nil, fmt.Errorf("network filters {%+v} is missing name field in filter: {%+v}", filters, filter) + } + if seenNames[name] { + return nil, fmt.Errorf("network filters {%+v} has duplicate filter name %q", filters, name) + } + seenNames[name] = true + + // Network filters have a oneof field named `config_type` where we + // only support `TypedConfig` variant. + switch typ := filter.GetConfigType().(type) { + case *v3listenerpb.Filter_TypedConfig: + // The typed_config field has an `anypb.Any` proto which could + // directly contain the serialized bytes of the actual filter + // configuration, or it could be encoded as a `TypedStruct`. + // TODO: Add support for `TypedStruct`. + tc := filter.GetTypedConfig() + + // The only network filter that we currently support is the v3 + // HttpConnectionManager. So, we can directly check the type_url + // and unmarshal the config. + // TODO: Implement a registry of supported network filters (like + // we have for HTTP filters), when we have to support network + // filters other than HttpConnectionManager. + if tc.GetTypeUrl() != version.V3HTTPConnManagerURL { + return nil, fmt.Errorf("network filters {%+v} has unsupported network filter %q in filter {%+v}", filters, tc.GetTypeUrl(), filter) } - if seenNames[name] { - return fmt.Errorf("filter chain {%+v} has duplicate filter name %q", filterChain, name) + hcm := &v3httppb.HttpConnectionManager{} + if err := ptypes.UnmarshalAny(tc, hcm); err != nil { + return nil, fmt.Errorf("network filters {%+v} failed unmarshaling of network filter {%+v}: %v", filters, filter, err) } - seenNames[name] = true - - // Network filters have a oneof field named `config_type` where we - // only support `TypedConfig` variant. - switch typ := filter.GetConfigType().(type) { - case *v3listenerpb.Filter_TypedConfig: - // The typed_config field has an `anypb.Any` proto which could - // directly contain the serialized bytes of the actual filter - // configuration, or it could be encoded as a `TypedStruct`. - // TODO: Add support for `TypedStruct`. - tc := filter.GetTypedConfig() - - // The only network filter that we currently support is the v3 - // HttpConnectionManager. So, we can directly check the type_url - // and unmarshal the config. - // TODO: Implement a registry of supported network filters (like - // we have for HTTP filters), when we have to support network - // filters other than HttpConnectionManager. - if tc.GetTypeUrl() != version.V3HTTPConnManagerURL { - return fmt.Errorf("filter chain {%+v} has unsupported network filter %q in filter {%+v}", filterChain, tc.GetTypeUrl(), filter) - } - hcm := &v3httppb.HttpConnectionManager{} - if err := ptypes.UnmarshalAny(tc, hcm); err != nil { - return fmt.Errorf("filter chain {%+v} failed unmarshaling of network filter {%+v}: %v", filterChain, filter, err) - } - // We currently don't support HTTP filters on the server-side. - // We will be adding support for it in the future. So, we want - // to make sure that the http_filters configuration is valid. 
- if _, err := processHTTPFilters(hcm.GetHttpFilters(), true); err != nil { - return err - } + // "Any filters after HttpConnectionManager should be ignored during + // connection processing but still be considered for validity. + // HTTPConnectionManager must have valid http_filters." - A36 + filters, err := processHTTPFilters(hcm.GetHttpFilters(), true) + if err != nil { + return nil, fmt.Errorf("network filters {%+v} had invalid server side HTTP Filters {%+v}", filters, hcm.GetHttpFilters()) + } + if !seenHCM { + // TODO: Implement terminal filter logic, as per A36. + httpFilters = filters seenHCM = true - default: - return fmt.Errorf("filter chain {%+v} has unsupported config_type %T in filter %s", filterChain, typ, filter.GetName()) } - } - if !seenHCM { - return fmt.Errorf("filter chain {%+v} missing HttpConnectionManager filter", filterChain) + default: + return nil, fmt.Errorf("network filters {%+v} has unsupported config_type %T in filter %s", filters, typ, filter.GetName()) } } - return nil + if !seenHCM { + return nil, fmt.Errorf("network filters {%+v} missing HttpConnectionManager filter", filters) + } + return httpFilters, nil } // UnmarshalRouteConfig processes resources received in an RDS response, @@ -500,6 +493,16 @@ func routesProtoToSlice(routes []*v3routepb.Route, logger *grpclog.PrefixLogger, route.WeightedClusters = make(map[string]WeightedCluster) action := r.GetRoute() + + // Hash Policies are only applicable for a Ring Hash LB. + if env.RingHashSupport { + hp, err := hashPoliciesProtoToSlice(action.HashPolicy, logger) + if err != nil { + return nil, err + } + route.HashPolicies = hp + } + switch a := action.GetClusterSpecifier().(type) { case *v3routepb.RouteAction_Cluster: route.WeightedClusters[a.Cluster] = WeightedCluster{Weight: 1} @@ -561,6 +564,39 @@ func routesProtoToSlice(routes []*v3routepb.Route, logger *grpclog.PrefixLogger, return routesRet, nil } +func hashPoliciesProtoToSlice(policies []*v3routepb.RouteAction_HashPolicy, logger *grpclog.PrefixLogger) ([]*HashPolicy, error) { + var hashPoliciesRet []*HashPolicy + for _, p := range policies { + policy := HashPolicy{Terminal: p.Terminal} + switch p.GetPolicySpecifier().(type) { + case *v3routepb.RouteAction_HashPolicy_Header_: + policy.HashPolicyType = HashPolicyTypeHeader + policy.HeaderName = p.GetHeader().GetHeaderName() + if rr := p.GetHeader().GetRegexRewrite(); rr != nil { + regex := rr.GetPattern().GetRegex() + re, err := regexp.Compile(regex) + if err != nil { + return nil, fmt.Errorf("hash policy %+v contains an invalid regex %q", p, regex) + } + policy.Regex = re + policy.RegexSubstitution = rr.GetSubstitution() + } + case *v3routepb.RouteAction_HashPolicy_FilterState_: + if p.GetFilterState().GetKey() != "io.grpc.channel_id" { + logger.Infof("hash policy %+v contains an invalid key for filter state policy %q", p, p.GetFilterState().GetKey()) + continue + } + policy.HashPolicyType = HashPolicyTypeChannelID + default: + logger.Infof("hash policy %T is an unsupported hash policy", p.GetPolicySpecifier()) + continue + } + + hashPoliciesRet = append(hashPoliciesRet, &policy) + } + return hashPoliciesRet, nil +} + // UnmarshalCluster processes resources received in an CDS response, validates // them, and transforms them into a native struct which contains only fields we // are interested in. @@ -776,9 +812,6 @@ func securityConfigFromCommonTLSContext(common *v3tlspb.CommonTlsContext) (*Secu // the received cluster resource. Returns nil if no CircuitBreakers or no // Thresholds in CircuitBreakers. 
func circuitBreakersFromCluster(cluster *v3clusterpb.Cluster) *uint32 { - if !env.CircuitBreakingSupport { - return nil - } for _, threshold := range cluster.GetCircuitBreakers().GetThresholds() { if threshold.GetPriority() != v3corepb.RoutingPriority_DEFAULT { continue diff --git a/xds/internal/client/tests/client_test.go b/xds/internal/xdsclient/xdsclient_test.go similarity index 92% rename from xds/internal/client/tests/client_test.go rename to xds/internal/xdsclient/xdsclient_test.go index 755f0e05ea45..7b33c145f3bd 100644 --- a/xds/internal/client/tests/client_test.go +++ b/xds/internal/xdsclient/xdsclient_test.go @@ -18,7 +18,7 @@ * */ -package tests_test +package xdsclient_test import ( "testing" @@ -27,11 +27,11 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/internal/grpctest" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 API client. "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/version" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 API client. ) type s struct { diff --git a/xds/server.go b/xds/server.go index 3a2b629ae986..aad05d81d116 100644 --- a/xds/server.go +++ b/xds/server.go @@ -33,19 +33,18 @@ import ( "google.golang.org/grpc/internal/buffer" internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcsync" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" "google.golang.org/grpc/xds/internal/server" + "google.golang.org/grpc/xds/internal/xdsclient" ) const serverPrefix = "[xds-server %p] " var ( // These new functions will be overridden in unit tests. - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return xdsclient.New() } - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return grpc.NewServer(opts...) } @@ -58,21 +57,14 @@ func prefixLogger(p *GRPCServer) *internalgrpclog.PrefixLogger { return internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(serverPrefix, p)) } -// xdsClientInterface contains methods from xdsClient.Client which are used by -// the server. This is useful for overriding in unit tests. -type xdsClientInterface interface { - WatchListener(string, func(xdsclient.ListenerUpdate, error)) func() - BootstrapConfig() *bootstrap.Config - Close() -} - -// grpcServerInterface contains methods from grpc.Server which are used by the +// grpcServer contains methods from grpc.Server which are used by the // GRPCServer type here. This is useful for overriding in unit tests. -type grpcServerInterface interface { +type grpcServer interface { RegisterService(*grpc.ServiceDesc, interface{}) Serve(net.Listener) error Stop() GracefulStop() + GetServiceInfo() map[string]grpc.ServiceInfo } // GRPCServer wraps a gRPC server and provides server-side xDS functionality, by @@ -80,7 +72,7 @@ type grpcServerInterface interface { // grpc.ServiceRegistrar interface and can be passed to service registration // functions in IDL generated code. 
type GRPCServer struct { - gs grpcServerInterface + gs grpcServer quit *grpcsync.Event logger *internalgrpclog.PrefixLogger xdsCredsInUse bool @@ -90,7 +82,7 @@ type GRPCServer struct { // beginning of Serve(), where we have to decide if we have to create a // client or use an existing one. clientMu sync.Mutex - xdsC xdsClientInterface + xdsC xdsclient.XDSClient } // NewGRPCServer creates an xDS-enabled gRPC server using the passed in opts. @@ -131,8 +123,8 @@ func NewGRPCServer(opts ...grpc.ServerOption) *GRPCServer { func handleServerOptions(opts []grpc.ServerOption) *serverOptions { so := &serverOptions{} for _, opt := range opts { - if o, ok := opt.(serverOption); ok { - o.applyServerOption(so) + if o, ok := opt.(*serverOption); ok { + o.apply(so) } } return so @@ -145,6 +137,12 @@ func (s *GRPCServer) RegisterService(sd *grpc.ServiceDesc, ss interface{}) { s.gs.RegisterService(sd, ss) } +// GetServiceInfo returns a map from service names to ServiceInfo. +// Service names include the package names, in the form of <package>.<service>. +func (s *GRPCServer) GetServiceInfo() map[string]grpc.ServiceInfo { + return s.gs.GetServiceInfo() +} + // initXDSClient creates a new xdsClient if there is no existing one available. func (s *GRPCServer) initXDSClient() error { s.clientMu.Lock() @@ -154,6 +152,12 @@ func (s *GRPCServer) initXDSClient() error { return nil } + newXDSClient := newXDSClient + if s.opts.bootstrapContents != nil { + newXDSClient = func() (xdsclient.XDSClient, error) { + return xdsclient.NewClientWithBootstrapContents(s.opts.bootstrapContents) + } + } client, err := newXDSClient() if err != nil { return fmt.Errorf("xds: failed to create xds-client: %v", err) @@ -181,7 +185,6 @@ func (s *GRPCServer) Serve(lis net.Listener) error { if err := s.initXDSClient(); err != nil { return err } - cfg := s.xdsC.BootstrapConfig() if cfg == nil { return errors.New("bootstrap configuration is empty") diff --git a/xds/server_options.go b/xds/server_options.go index 44b7b374fd00..0918c097a3e5 100644 --- a/xds/server_options.go +++ b/xds/server_options.go @@ -25,31 +25,20 @@ import ( iserver "google.golang.org/grpc/xds/internal/server" ) -// ServingModeCallback returns a grpc.ServerOption which allows users to -// register a callback to get notified about serving mode changes. -func ServingModeCallback(cb ServingModeCallbackFunc) grpc.ServerOption { - return &smcOption{cb: cb} -} - -type serverOption interface { - applyServerOption(*serverOptions) +type serverOptions struct { + modeCallback ServingModeCallbackFunc + bootstrapContents []byte } -// smcOption is a server option containing a callback to be invoked when the -// serving mode changes. -type smcOption struct { - // Embedding the empty server option makes it safe to pass it to - // grpc.NewServer(). +type serverOption struct { grpc.EmptyServerOption - cb ServingModeCallbackFunc -} - -func (s *smcOption) applyServerOption(o *serverOptions) { - o.modeCallback = s.cb + apply func(*serverOptions) } -type serverOptions struct { - modeCallback ServingModeCallbackFunc +// ServingModeCallback returns a grpc.ServerOption which allows users to +// register a callback to get notified about serving mode changes. +func ServingModeCallback(cb ServingModeCallbackFunc) grpc.ServerOption { + return &serverOption{apply: func(o *serverOptions) { o.modeCallback = cb }} } // ServingMode indicates the current mode of operation of the server. @@ -82,3 +71,15 @@ type ServingModeChangeArgs struct { // not-serving mode. 
Err error } + +// BootstrapContentsForTesting returns a grpc.ServerOption which allows users +// to inject a bootstrap configuration used by only this server, instead of the +// global configuration from the environment variables. +// +// Testing Only +// +// This function should ONLY be used for testing and may not work with some +// other features, including the CSDS service. +func BootstrapContentsForTesting(contents []byte) grpc.ServerOption { + return &serverOption{apply: func(o *serverOptions) { o.bootstrapContents = contents }} +} diff --git a/xds/server_test.go b/xds/server_test.go index e16ac36b01f2..00b8518fa9d9 100644 --- a/xds/server_test.go +++ b/xds/server_test.go @@ -32,6 +32,7 @@ import ( v3corepb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" v3listenerpb "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + v3httppb "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" v3tlspb "github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -39,10 +40,10 @@ import ( "google.golang.org/grpc/credentials/xds" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/testutils" - xdsclient "google.golang.org/grpc/xds/internal/client" - "google.golang.org/grpc/xds/internal/client/bootstrap" xdstestutils "google.golang.org/grpc/xds/internal/testutils" "google.golang.org/grpc/xds/internal/testutils/fakeclient" + "google.golang.org/grpc/xds/internal/xdsclient" + "google.golang.org/grpc/xds/internal/xdsclient/bootstrap" ) const ( @@ -86,6 +87,10 @@ func (f *fakeGRPCServer) GracefulStop() { f.gracefulStopCh.Send(nil) } +func (f *fakeGRPCServer) GetServiceInfo() map[string]grpc.ServiceInfo { + panic("implement me") +} + func newFakeGRPCServer() *fakeGRPCServer { return &fakeGRPCServer{ done: make(chan struct{}), @@ -133,7 +138,7 @@ func (s) TestNewServer(t *testing.T) { wantServerOpts := len(test.serverOpts) + 2 origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { if got := len(opts); got != wantServerOpts { t.Fatalf("%d ServerOptions passed to grpc.Server, want %d", got, wantServerOpts) } @@ -161,7 +166,7 @@ func (s) TestRegisterService(t *testing.T) { fs := newFakeGRPCServer() origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { return fs } + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return fs } defer func() { newGRPCServer = origNewGRPCServer }() s := NewGRPCServer() @@ -247,7 +252,7 @@ func (p *fakeProvider) Close() { func setupOverrides() (*fakeGRPCServer, *testutils.Channel, func()) { clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() c.SetBootstrapConfig(&bootstrap.Config{ BalancerName: "dummyBalancer", @@ -262,7 +267,7 @@ func setupOverrides() (*fakeGRPCServer, *testutils.Channel, func()) { fs := newFakeGRPCServer() origNewGRPCServer := newGRPCServer - newGRPCServer = func(opts ...grpc.ServerOption) grpcServerInterface { return fs } + newGRPCServer = func(opts ...grpc.ServerOption) grpcServer { return fs } return fs, clientCh, func() { newXDSClient = origNewXDSClient @@ -277,7 +282,7 @@ func setupOverrides() (*fakeGRPCServer, *testutils.Channel, func()) { 
func setupOverridesForXDSCreds(includeCertProviderCfg bool) (*testutils.Channel, func()) { clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() bc := &bootstrap.Config{ BalancerName: "dummyBalancer", @@ -544,7 +549,7 @@ func (s) TestServeBootstrapConfigInvalid(t *testing.T) { // xdsClient with the specified bootstrap configuration. clientCh := testutils.NewChannel() origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { c := fakeclient.NewClient() c.SetBootstrapConfig(test.bootstrapConfig) clientCh.Send(c) @@ -587,7 +592,7 @@ func (s) TestServeBootstrapConfigInvalid(t *testing.T) { // verifies that Server() exits with a non-nil error. func (s) TestServeNewClientFailure(t *testing.T) { origNewXDSClient := newXDSClient - newXDSClient = func() (xdsClientInterface, error) { + newXDSClient = func() (xdsclient.XDSClient, error) { return nil, errors.New("xdsClient creation failed") } defer func() { newXDSClient = origNewXDSClient }() @@ -679,6 +684,14 @@ func (s) TestHandleListenerUpdate_NoXDSCreds(t *testing.T) { }), }, }, + Filters: []*v3listenerpb.Filter{ + { + Name: "filter-1", + ConfigType: &v3listenerpb.Filter_TypedConfig{ + TypedConfig: testutils.MarshalAny(&v3httppb.HttpConnectionManager{}), + }, + }, + }, }, }, }) diff --git a/xds/xds.go b/xds/xds.go index 23c88903f40b..864d2e6a2e3a 100644 --- a/xds/xds.go +++ b/xds/xds.go @@ -38,14 +38,15 @@ import ( v3statusgrpc "github.com/envoyproxy/go-control-plane/envoy/service/status/v3" "google.golang.org/grpc" internaladmin "google.golang.org/grpc/internal/admin" + "google.golang.org/grpc/resolver" "google.golang.org/grpc/xds/csds" _ "google.golang.org/grpc/credentials/tls/certprovider/pemfile" // Register the file watcher certificate provider plugin. _ "google.golang.org/grpc/xds/internal/balancer" // Register the balancers. - _ "google.golang.org/grpc/xds/internal/client/v2" // Register the v2 xDS API client. - _ "google.golang.org/grpc/xds/internal/client/v3" // Register the v3 xDS API client. _ "google.golang.org/grpc/xds/internal/httpfilter/fault" // Register the fault injection filter. - _ "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + xdsresolver "google.golang.org/grpc/xds/internal/resolver" // Register the xds_resolver. + _ "google.golang.org/grpc/xds/internal/xdsclient/v2" // Register the v2 xDS API client. + _ "google.golang.org/grpc/xds/internal/xdsclient/v3" // Register the v3 xDS API client. ) func init() { @@ -76,3 +77,16 @@ func init() { return csdss.Close, nil }) } + +// NewXDSResolverWithConfigForTesting creates a new xds resolver builder using +// the provided xds bootstrap config instead of the global configuration from +// the supported environment variables. The resolver.Builder is meant to be +// used in conjunction with the grpc.WithResolvers DialOption. +// +// Testing Only +// +// This function should ONLY be used for testing and may not work with some +// other features, including the CSDS service. +func NewXDSResolverWithConfigForTesting(bootstrapConfig []byte) (resolver.Builder, error) { + return xdsresolver.NewBuilder(bootstrapConfig) +}
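The builder returned by NewXDSResolverWithConfigForTesting above is meant to be installed per-ClientConn through grpc.WithResolvers. A minimal sketch of that wiring, assuming a placeholder bootstrap config and target name (neither appears in this change):

package main

import (
	"log"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/xds"
)

func main() {
	// Placeholder bootstrap contents; a real config must point at a
	// reachable management server and typically includes a node ID.
	bootstrapContents := []byte(`{
		"xds_servers": [{"server_uri": "localhost:9999"}],
		"node": {"id": "test-node"}
	}`)

	// Build a resolver that uses these contents instead of the global
	// configuration from the environment variables.
	r, err := xds.NewXDSResolverWithConfigForTesting(bootstrapContents)
	if err != nil {
		log.Fatalf("failed to create xds resolver for testing: %v", err)
	}

	// The injected config only applies to ClientConns that carry this
	// resolver explicitly.
	cc, err := grpc.Dial("xds:///test.server",
		grpc.WithResolvers(r),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		log.Fatalf("grpc.Dial() failed: %v", err)
	}
	defer cc.Close()
}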
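BootstrapContentsForTesting composes with ServingModeCallback on the server side. A sketch under the assumption that ServingModeCallbackFunc has the signature func(net.Addr, ServingModeChangeArgs) and that ServingModeChangeArgs carries a Mode field next to the Err field shown above (the full struct is outside this diff):

package main

import (
	"log"
	"net"

	"google.golang.org/grpc/xds"
)

func main() {
	// Placeholder bootstrap contents, as in the client-side sketch.
	bootstrapContents := []byte(`{"xds_servers": [{"server_uri": "localhost:9999"}]}`)

	server := xds.NewGRPCServer(
		// Both values land in serverOptions via the *serverOption
		// unwrapping in handleServerOptions.
		xds.ServingModeCallback(func(addr net.Addr, args xds.ServingModeChangeArgs) {
			log.Printf("listener %v: serving mode %v (err: %v)", addr, args.Mode, args.Err)
		}),
		xds.BootstrapContentsForTesting(bootstrapContents),
	)

	lis, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		log.Fatalf("net.Listen() failed: %v", err)
	}
	if err := server.Serve(lis); err != nil {
		log.Printf("Serve() returned: %v", err)
	}
}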
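The four new *WatchNACKError tests encode the user-visible half of the NewConnectionError plumbing: watchers now receive a non-nil error through their callback when an update is NACK'ed or the ADS stream breaks, rather than silently keeping stale state. A sketch of a watcher written against that contract, assuming the XDSClient interface exposes WatchCluster with the signature the tests use:

package demo

import (
	"log"

	"google.golang.org/grpc/xds/internal/xdsclient"
)

// watchCluster registers a cluster watch whose callback separates
// NACK/connection errors from valid updates.
func watchCluster(c xdsclient.XDSClient, name string) (cancel func()) {
	return c.WatchCluster(name, func(u xdsclient.ClusterUpdate, err error) {
		if err != nil {
			// Delivered when the resource is NACK'ed or the connection to
			// the management server fails; any previously delivered good
			// update remains in effect.
			log.Printf("watch for cluster %q reported error: %v", name, err)
			return
		}
		log.Printf("cluster %q updated: %+v", name, u)
	})
}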
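Finally, routesProtoToSlice now parses hash policies (gated on env.RingHashSupport). A sketch of the two policy variants hashPoliciesProtoToSlice accepts, built from go-control-plane types; the header name is illustrative only:

package demo

import (
	v3routepb "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
)

// exampleHashPolicies builds one header policy and one channel-ID policy.
func exampleHashPolicies() []*v3routepb.RouteAction_HashPolicy {
	return []*v3routepb.RouteAction_HashPolicy{
		{
			// Hash on a request header; an optional regex_rewrite on the
			// header value is also honored by the parsing code.
			PolicySpecifier: &v3routepb.RouteAction_HashPolicy_Header_{
				Header: &v3routepb.RouteAction_HashPolicy_Header{HeaderName: "x-session-id"},
			},
		},
		{
			// Filter-state policy: only the key "io.grpc.channel_id" is
			// recognized; other keys are skipped with a log.
			PolicySpecifier: &v3routepb.RouteAction_HashPolicy_FilterState_{
				FilterState: &v3routepb.RouteAction_HashPolicy_FilterState{Key: "io.grpc.channel_id"},
			},
			Terminal: true,
		},
	}
}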