diff --git a/.github/workflows/codecov.yaml b/.github/workflows/codecov.yaml index fb6742261..8ae24e36c 100644 --- a/.github/workflows/codecov.yaml +++ b/.github/workflows/codecov.yaml @@ -100,6 +100,6 @@ jobs: bash test/codecov.sh - name: Upload Codecov - uses: codecov/codecov-action@v3 + uses: codecov/codecov-action@v4 with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 0034a354c..e398fa346 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -53,7 +53,7 @@ jobs: password: ${{ secrets.POLARIS_DOCKER_PASSWORD }} - name: Build Server - id: build-server + id: build-hub-server env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} @@ -63,7 +63,7 @@ jobs: make build-docker IMAGE_TAG=${DOCKER_TAG} - name: Build Prometheus - id: build-prom + id: build-hub-prom env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} @@ -72,3 +72,33 @@ jobs: cd release/standalone/docker/prometheus ls -lstrh bash build_docker_prom.sh ${DOCKER_TAG} + + - name: Log in to Tencent CCR + uses: docker/login-action@v1 + with: + registry: ccr.ccs.tencentyun.com + username: ${{ secrets.TENCENT_DOCKER_NAME }} + password: ${{ secrets.TENCENT_DOCKER_PASSWORD }} + + - name: Build Server + id: build-tencent-server + env: + DOCKER_REPOSITORY: ccr.ccs.tencentyun.com/polarismesh + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} + DOCKER_TAG: ${{ steps.get_version.outputs.VERSION }} + run: | + ls -lstrh + make build-docker IMAGE_TAG=${DOCKER_TAG} + + - name: Build Prometheus + id: build-tencent-prom + env: + DOCKER_REPOSITORY: ccr.ccs.tencentyun.com/polarismesh + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} + DOCKER_TAG: ${{ steps.get_version.outputs.VERSION }} + run: | + cd release/standalone/docker/prometheus + ls -lstrh + bash build_docker_prom.sh ${DOCKER_TAG} diff --git a/apiserver/eurekaserver/access_test.go b/apiserver/eurekaserver/access_test.go index 24964f2a3..c3d3bce08 100644 --- a/apiserver/eurekaserver/access_test.go +++ b/apiserver/eurekaserver/access_test.go @@ -197,7 +197,7 @@ func TestCreateInstance(t *testing.T) { time.Sleep(5 * time.Second) instanceId := fmt.Sprintf("%s_%s_%d", appId, host, startPort) code := eurekaSrv.deregisterInstance(context.Background(), namespace, appId, instanceId, false) - assert.Equal(t, api.ExecuteSuccess, code) + assert.Equal(t, api.ExecuteSuccess, code, fmt.Sprintf("%d", code)) time.Sleep(20 * time.Second) deltaReq := restful.NewRequest(httpRequest) @@ -244,18 +244,17 @@ func Test_EurekaWrite(t *testing.T) { injectRestfulReqPathParameters(t, restfulReq, map[string]string{ ParamAppId: mockIns.AppName, }) - // 这里是异步注册 eurekaSrv.RegisterApplication(restfulReq, restful.NewResponse(mockRsp)) assert.Equal(t, http.StatusNoContent, mockRsp.statusCode) assert.Equal(t, restfulReq.Attribute(statusCodeHeader), uint32(apimodel.Code_ExecuteSuccess)) - time.Sleep(5 * time.Second) - saveIns, err := eurekaSrv.originDiscoverSvr.Cache().GetStore().GetInstance(mockIns.InstanceId) + _ = discoverSuit.CacheMgr().TestUpdate() + saveIns, err := discoverSuit.Storage.GetInstance(mockIns.InstanceId) assert.NoError(t, err) assert.NotNil(t, saveIns) t.Run("UpdateStatus", func(t *testing.T) { - t.Run("StatusUnknown", func(t *testing.T) { + t.Run("01_StatusUnknown", func(t *testing.T) { mockReq := httptest.NewRequest("", fmt.Sprintf("http://127.0.0.1:8761/eureka/v2/apps/%s/%s/status", mockIns.AppName, mockIns.InstanceId), nil) mockReq.PostForm = url.Values{} @@ -278,7 +277,7 @@ func Test_EurekaWrite(t *testing.T) { assert.False(t, saveIns.Isolate()) }) - t.Run("StatusDown", func(t *testing.T) { + t.Run("02_StatusDown", func(t *testing.T) { mockReq := httptest.NewRequest("", fmt.Sprintf("http://127.0.0.1:8761/eureka/v2/apps/%s/%s/status", mockIns.AppName, mockIns.InstanceId), nil) mockReq.PostForm = url.Values{} @@ -301,7 +300,7 @@ func Test_EurekaWrite(t *testing.T) { assert.Equal(t, StatusDown, saveIns.Proto.Metadata[InternalMetadataStatus]) }) - t.Run("StatusUp", func(t *testing.T) { + t.Run("03_StatusUp", func(t *testing.T) { mockReq := httptest.NewRequest("", fmt.Sprintf("http://127.0.0.1:8761/eureka/v2/apps/%s/%s/status", mockIns.AppName, mockIns.InstanceId), nil) mockReq.PostForm = url.Values{} @@ -324,9 +323,9 @@ func Test_EurekaWrite(t *testing.T) { assert.Equal(t, StatusUp, saveIns.Proto.Metadata[InternalMetadataStatus]) }) - t.Run("Polaris_UpdateInstances", func(t *testing.T) { + t.Run("04_Polaris_UpdateInstances", func(t *testing.T) { defer func() { - rsp := discoverSuit.OriginDiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*service_manage.Instance{ + rsp := discoverSuit.DiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*service_manage.Instance{ { Id: wrapperspb.String(mockIns.InstanceId), Isolate: wrapperspb.Bool(false), @@ -334,7 +333,7 @@ func Test_EurekaWrite(t *testing.T) { }) assert.Equal(t, apimodel.Code_ExecuteSuccess, apimodel.Code(rsp.GetCode().GetValue())) }() - rsp := discoverSuit.OriginDiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*service_manage.Instance{ + rsp := discoverSuit.DiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*service_manage.Instance{ { Id: wrapperspb.String(mockIns.InstanceId), Isolate: wrapperspb.Bool(true), @@ -349,8 +348,8 @@ func Test_EurekaWrite(t *testing.T) { assert.Equal(t, StatusOutOfService, saveIns.Proto.Metadata[InternalMetadataStatus]) }) - t.Run("Polaris_UpdateInstancesIsolate", func(t *testing.T) { - rsp := discoverSuit.OriginDiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*service_manage.Instance{ + t.Run("05_Polaris_UpdateInstancesIsolate", func(t *testing.T) { + rsp := discoverSuit.DiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*service_manage.Instance{ { Id: wrapperspb.String(mockIns.InstanceId), Isolate: wrapperspb.Bool(true), diff --git a/apiserver/eurekaserver/chain.go b/apiserver/eurekaserver/chain.go index 642122037..d9f64e944 100644 --- a/apiserver/eurekaserver/chain.go +++ b/apiserver/eurekaserver/chain.go @@ -35,7 +35,7 @@ type ( func (h *EurekaServer) registerInstanceChain() { svr := h.originDiscoverSvr.(*service.Server) svr.AddInstanceChain(&EurekaInstanceChain{ - s: h.namingServer.Cache().GetStore(), + s: svr.Store(), }) } diff --git a/apiserver/eurekaserver/write.go b/apiserver/eurekaserver/write.go index d474105ba..988a502ca 100644 --- a/apiserver/eurekaserver/write.go +++ b/apiserver/eurekaserver/write.go @@ -32,6 +32,7 @@ import ( "github.com/polarismesh/polaris/common/model" commonstore "github.com/polarismesh/polaris/common/store" "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/service" ) func checkOrBuildNewInstanceId(appId string, instId string, generateUniqueInstId bool) string { @@ -256,7 +257,8 @@ func (h *EurekaServer) updateStatus( }) instanceId = checkOrBuildNewInstanceIdByNamespace(namespace, h.namespace, appId, instanceId, h.generateUniqueInstId) - saveIns, err := h.originDiscoverSvr.Cache().GetStore().GetInstance(instanceId) + svr := h.originDiscoverSvr.(*service.Server) + saveIns, err := svr.Store().GetInstance(instanceId) if err != nil { eurekalog.Error("[EUREKA-SERVER] get instance from store when update status", zap.Error(err)) return uint32(commonstore.StoreCode2APICode(err)) diff --git a/apiserver/nacosserver/core/storage.go b/apiserver/nacosserver/core/storage.go index 30992b245..c7826443f 100644 --- a/apiserver/nacosserver/core/storage.go +++ b/apiserver/nacosserver/core/storage.go @@ -29,7 +29,7 @@ import ( "golang.org/x/sync/singleflight" nacosmodel "github.com/polarismesh/polaris/apiserver/nacosserver/model" - "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" commontime "github.com/polarismesh/polaris/common/time" @@ -50,7 +50,7 @@ type ( ins []*nacosmodel.Instance, healthyCount int32) *nacosmodel.ServiceInfo ) -func NewNacosDataStorage(cacheMgr *cache.CacheManager) *NacosDataStorage { +func NewNacosDataStorage(cacheMgr cachetypes.CacheManager) *NacosDataStorage { ctx, cancel := context.WithCancel(context.Background()) notifier, notifierFinish := context.WithCancel(context.Background()) store := &NacosDataStorage{ @@ -67,7 +67,7 @@ func NewNacosDataStorage(cacheMgr *cache.CacheManager) *NacosDataStorage { // NacosDataStorage . type NacosDataStorage struct { - cacheMgr *cache.CacheManager + cacheMgr cachetypes.CacheManager ctx context.Context cancel context.CancelFunc @@ -82,7 +82,7 @@ type NacosDataStorage struct { revisions map[string]string } -func (n *NacosDataStorage) Cache() *cache.CacheManager { +func (n *NacosDataStorage) Cache() cachetypes.CacheManager { return n.cacheMgr } @@ -343,7 +343,7 @@ func SelectInstancesWithHealthyProtection(ctx *FilterContext, result *nacosmodel return result } -func ToNacosService(cacheMgr *cache.CacheManager, namespace, service, group string) *nacosmodel.ServiceMetadata { +func ToNacosService(cacheMgr cachetypes.CacheManager, namespace, service, group string) *nacosmodel.ServiceMetadata { ret := &nacosmodel.ServiceMetadata{ ServiceKey: nacosmodel.ServiceKey{ Namespace: namespace, diff --git a/apiserver/nacosserver/v1/config/server.go b/apiserver/nacosserver/v1/config/server.go index f3d6e4f48..71a26efd2 100644 --- a/apiserver/nacosserver/v1/config/server.go +++ b/apiserver/nacosserver/v1/config/server.go @@ -21,7 +21,7 @@ import ( "github.com/polarismesh/polaris/apiserver/nacosserver/core" "github.com/polarismesh/polaris/apiserver/nacosserver/v2/remote" "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/config" "github.com/polarismesh/polaris/namespace" ) @@ -45,7 +45,7 @@ type ConfigServer struct { namespaceSvr namespace.NamespaceOperateServer configSvr config.ConfigCenterServer originConfigSvr config.ConfigCenterServer - cacheSvr *cache.CacheManager + cacheSvr cachetypes.CacheManager } func (h *ConfigServer) Initialize(opt *ServerOption) error { diff --git a/apiserver/nacosserver/v1/discover/instance.go b/apiserver/nacosserver/v1/discover/instance.go index 863edcacf..87db26324 100644 --- a/apiserver/nacosserver/v1/discover/instance.go +++ b/apiserver/nacosserver/v1/discover/instance.go @@ -31,6 +31,7 @@ import ( "github.com/polarismesh/polaris/apiserver/nacosserver/model" commonmodel "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/service" ) func (n *DiscoverServer) handleRegister(ctx context.Context, namespace, serviceName string, ins *model.Instance) error { @@ -57,7 +58,8 @@ func (n *DiscoverServer) handleUpdate(ctx context.Context, namespace, serviceNam } specIns.Id = wrapperspb.String(insId) } - saveIns, err := n.discoverSvr.Cache().GetStore().GetInstance(specIns.GetId().GetValue()) + svr := n.discoverSvr.(*service.Server) + saveIns, err := svr.Store().GetInstance(specIns.GetId().GetValue()) if err != nil { return &model.NacosError{ ErrCode: int32(model.ExceptionCode_ServerError), diff --git a/apiserver/nacosserver/v2/config/server.go b/apiserver/nacosserver/v2/config/server.go index 27b420399..09b7196bd 100644 --- a/apiserver/nacosserver/v2/config/server.go +++ b/apiserver/nacosserver/v2/config/server.go @@ -22,7 +22,7 @@ import ( nacospb "github.com/polarismesh/polaris/apiserver/nacosserver/v2/pb" "github.com/polarismesh/polaris/apiserver/nacosserver/v2/remote" "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/config" "github.com/polarismesh/polaris/namespace" ) @@ -49,7 +49,7 @@ type ConfigServer struct { namespaceSvr namespace.NamespaceOperateServer configSvr config.ConfigCenterServer originConfigSvr config.ConfigCenterServer - cacheSvr *cache.CacheManager + cacheSvr cachetypes.CacheManager handleRegistry map[string]*remote.RequestHandlerWarrper } diff --git a/apiserver/nacosserver/v2/discover/checker.go b/apiserver/nacosserver/v2/discover/checker.go index b8331b209..cc69b9644 100644 --- a/apiserver/nacosserver/v2/discover/checker.go +++ b/apiserver/nacosserver/v2/discover/checker.go @@ -30,7 +30,7 @@ import ( nacosmodel "github.com/polarismesh/polaris/apiserver/nacosserver/model" "github.com/polarismesh/polaris/apiserver/nacosserver/v2/remote" - "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/utils" @@ -44,7 +44,7 @@ type Checker struct { discoverSvr service.DiscoverServer healthSvr *healthcheck.Server - cacheMgr *cache.CacheManager + cacheMgr cachetypes.CacheManager connMgr *remote.ConnectionManager clientMgr *ConnectionClientManager @@ -192,6 +192,8 @@ func (c *Checker) runCheck(ctx context.Context) { // BUT: 一个实例 T1 时刻对应长连接为 Conn-1,T2 时刻对应的长连接为 Conn-2,但是在 T1 ~ T2 之间的某个时刻检测发现长连接不存在 // 此时发起一个反注册请求,该请求在 T3 时刻发起,是否会影响 T2 时刻注册上来的实例? func (c *Checker) realCheck() { + svr := c.discoverSvr.(*service.Server) + defer func() { if err := recover(); err != nil { var buf [4086]byte @@ -267,7 +269,7 @@ func (c *Checker) realCheck() { } nacoslog.Info("[NACOS-V2][Checker] batch set instance health_status to unhealthy", zap.Any("instance-ids", ids)) - if err := c.discoverSvr.Cache().GetStore(). + if err := svr.Store(). BatchSetInstanceHealthStatus(ids, model.StatusBoolToInt(false), utils.NewUUID()); err != nil { nacoslog.Error("[NACOS-V2][Checker] batch set instance health_status to unhealthy", zap.Any("instance-ids", ids), zap.Error(err)) @@ -281,7 +283,7 @@ func (c *Checker) realCheck() { } nacoslog.Info("[NACOS-V2][Checker] batch set instance health_status to healty", zap.Any("instance-ids", ids)) - if err := c.discoverSvr.Cache().GetStore(). + if err := svr.Store(). BatchSetInstanceHealthStatus(ids, model.StatusBoolToInt(true), utils.NewUUID()); err != nil { nacoslog.Error("[NACOS-V2][Checker] batch set instance health_status to healty", zap.Any("instance-ids", ids), zap.Error(err)) diff --git a/auth/user/group.go b/auth/user/group.go index 4dcdc301c..d40a6ae7a 100644 --- a/auth/user/group.go +++ b/auth/user/group.go @@ -124,7 +124,7 @@ func (svr *Server) UpdateGroup(ctx context.Context, req *apisecurity.ModifyUserG return errResp } - modifyReq, needUpdate := updateGroupAttribute(ctx, data.UserGroup, req) + modifyReq, needUpdate := UpdateGroupAttribute(ctx, data.UserGroup, req) if !needUpdate { log.Info("update group data no change, no need update", utils.RequestID(ctx), zap.String("group", req.String())) @@ -385,8 +385,8 @@ func (svr *Server) checkUpdateGroup(ctx context.Context, req *apisecurity.Modify return nil } -// updateGroupAttribute 更新计算用户组更新时的结构体数据,并判断是否需要执行更新操作 -func updateGroupAttribute(ctx context.Context, old *model.UserGroup, newUser *apisecurity.ModifyUserGroup) ( +// UpdateGroupAttribute 更新计算用户组更新时的结构体数据,并判断是否需要执行更新操作 +func UpdateGroupAttribute(ctx context.Context, old *model.UserGroup, newUser *apisecurity.ModifyUserGroup) ( *model.ModifyUserGroup, bool) { var ( needUpdate bool diff --git a/auth/user/user_test.go b/auth/user/user_test.go index 93219e528..907f78f68 100644 --- a/auth/user/user_test.go +++ b/auth/user/user_test.go @@ -24,6 +24,7 @@ import ( "github.com/golang/mock/gomock" "github.com/golang/protobuf/ptypes/wrappers" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" "github.com/stretchr/testify/assert" @@ -318,6 +319,32 @@ func Test_server_CreateUsers(t *testing.T) { }) } +func Test_server_Login(t *testing.T) { + + userTest := newUserTest(t) + defer userTest.Clean() + + t.Run("正常登陆", func(t *testing.T) { + rsp := userTest.svr.Login(&apisecurity.LoginRequest{ + Name: &wrappers.StringValue{Value: userTest.users[0].Name}, + Password: &wrappers.StringValue{Value: "polaris"}, + }) + + assert.True(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + }) + + t.Run("错误的密码", func(t *testing.T) { + rsp := userTest.svr.Login(&apisecurity.LoginRequest{ + Name: &wrappers.StringValue{Value: userTest.users[0].Name}, + Password: &wrappers.StringValue{Value: "polaris_123"}, + }) + + assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + assert.Equal(t, uint32(apimodel.Code_NotAllowedAccess), rsp.GetCode().GetValue()) + assert.Contains(t, rsp.GetInfo().GetValue(), model.ErrorWrongUsernameOrPassword.Error()) + }) +} + func Test_server_UpdateUser(t *testing.T) { userTest := newUserTest(t) diff --git a/bootstrap/server.go b/bootstrap/server.go index 6ae2e0265..b3c17f314 100644 --- a/bootstrap/server.go +++ b/bootstrap/server.go @@ -569,7 +569,7 @@ func polarisServiceRegister(polarisService *boot_config.PolarisService, apiServe // selfRegister 服务自注册 func selfRegister( host string, port uint32, protocol string, isolated bool, polarisService *boot_config.Service, hbInterval int) error { - server, err := service.GetOriginServer() + server, err := service.GetServer() if err != nil { return err } @@ -609,7 +609,7 @@ func selfRegister( Metadata: metadata, } - resp := server.CreateInstance(genContext(), req) + resp := server.RegisterInstance(genContext(), req) if api.CalcCode(resp) != 200 { // 如果self之前注册过,那么可以忽略 if resp.GetCode().GetValue() != api.ExistedResource { diff --git a/cache/api/types.go b/cache/api/types.go index 387e798bd..3020e2fb4 100644 --- a/cache/api/types.go +++ b/cache/api/types.go @@ -119,6 +119,10 @@ type ConfigEntry struct { // CacheManager type CacheManager interface { + // GetUpdateCacheInterval . + GetUpdateCacheInterval() time.Duration + // GetReportInterval . + GetReportInterval() time.Duration // GetCacher GetCacher(cacheIndex CacheIndex) Cache // RegisterCacher @@ -141,6 +145,8 @@ type CacheManager interface { FaultDetector() FaultDetectCache // ServiceContract 获取服务契约缓存 ServiceContract() ServiceContractCache + // LaneRule 泳道规则 + LaneRule() LaneCache // User Get user information cache information User() UserCache // AuthStrategy Get authentication cache information @@ -576,11 +582,13 @@ type BaseCache struct { lock sync.RWMutex // firstUpdate Whether the cache is loaded for the first time // this field can only make value on exec initialize/clean, and set it to false on exec update - firstUpdate bool - s store.Store - lastFetchTime int64 - lastMtimes map[string]time.Time - CacheMgr CacheManager + firstUpdate bool + s store.Store + lastFetchTime int64 + lastMtimes map[string]time.Time + CacheMgr CacheManager + reportMetrics func() + lastReportMetricsTime time.Time } func NewBaseCache(s store.Store, cacheMgr CacheManager) *BaseCache { @@ -593,6 +601,17 @@ func NewBaseCache(s store.Store, cacheMgr CacheManager) *BaseCache { return c } +func NewBaseCacheWithRepoerMetrics(s store.Store, cacheMgr CacheManager, reportMetrics func()) *BaseCache { + c := &BaseCache{ + s: s, + CacheMgr: cacheMgr, + reportMetrics: reportMetrics, + } + + c.initialize() + return c +} + func (bc *BaseCache) initialize() { bc.lock.Lock() defer bc.lock.Unlock() @@ -696,6 +715,12 @@ func (bc *BaseCache) DoCacheUpdate(name string, executor func() (map[string]time if total >= 0 { metrics.RecordCacheUpdateCost(time.Since(start), name, total) } + if bc.reportMetrics != nil { + if time.Since(bc.lastReportMetricsTime) >= bc.CacheMgr.GetReportInterval() { + bc.reportMetrics() + bc.lastReportMetricsTime = start + } + } bc.firstUpdate = false return nil } diff --git a/cache/auth/strategy.go b/cache/auth/strategy.go index a5d0f8e29..e71d624fe 100644 --- a/cache/auth/strategy.go +++ b/cache/auth/strategy.go @@ -100,12 +100,9 @@ func (sc *strategyCache) realUpdate() (map[string]time.Time, int64, error) { } lastMtimes, add, update, del := sc.setStrategys(strategies) - timeDiff := time.Since(start) - if timeDiff > time.Second { - log.Info("[Cache][AuthStrategy] get more auth strategy", - zap.Int("add", add), zap.Int("update", update), zap.Int("delete", del), - zap.Time("last", lastTime), zap.Duration("used", time.Since(start))) - } + log.Info("[Cache][AuthStrategy] get more auth strategy", + zap.Int("add", add), zap.Int("update", update), zap.Int("delete", del), + zap.Time("last", lastTime), zap.Duration("used", time.Since(start))) return lastMtimes, int64(len(strategies)), nil } @@ -409,23 +406,24 @@ func (sc *strategyCache) getStrategyDetails(uid string, gid string) []*model.Str strategyIds = sets.ToSlice() } + result := make([]*model.StrategyDetail, 0, 16) if len(strategyIds) > 0 { - result := make([]*model.StrategyDetail, 0, 16) for i := range strategyIds { strategy, ok := sc.strategys.Load(strategyIds[i]) if ok { result = append(result, strategy.StrategyDetail) } } - - return result } - - return nil + return result } // IsResourceLinkStrategy 校验 func (sc *strategyCache) IsResourceLinkStrategy(resType apisecurity.ResourceType, resId string) bool { + hasLinkRule := func(sets *utils.SyncSet[string]) bool { + return sets.Len() != 0 + } + switch resType { case apisecurity.ResourceType_Namespaces: val, ok := sc.namespace2Strategy.Load(resId) @@ -440,7 +438,3 @@ func (sc *strategyCache) IsResourceLinkStrategy(resType apisecurity.ResourceType return true } } - -func hasLinkRule(sets *utils.SyncSet[string]) bool { - return sets.Len() != 0 -} diff --git a/cache/auth/strategy_test.go b/cache/auth/strategy_test.go index 84a8911d6..8549f757f 100644 --- a/cache/auth/strategy_test.go +++ b/cache/auth/strategy_test.go @@ -20,6 +20,7 @@ package auth import ( "fmt" "testing" + "time" "github.com/golang/mock/gomock" apisecurity "github.com/polarismesh/specification/source/go/api/v1/security" @@ -32,7 +33,44 @@ import ( "github.com/polarismesh/polaris/store/mock" ) -func Test_strategyCache_IsResourceEditable_1(t *testing.T) { +func Test_strategyCache(t *testing.T) { + t.Run("get_policy", func(t *testing.T) { + ctrl := gomock.NewController(t) + mockCacheMgr := cachemock.NewMockCacheManager(ctrl) + mockStore := mock.NewMockStore(ctrl) + + t.Cleanup(func() { + ctrl.Finish() + }) + + userCache := NewUserCache(mockStore, mockCacheMgr) + strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) + + mockStore.EXPECT().GetUnixSecond(gomock.Any()).Return(time.Now().Unix(), nil) + mockStore.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).Return(buildStrategies(10), nil).AnyTimes() + mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() + + userCache.Initialize(map[string]interface{}{}) + strategyCache.Initialize(map[string]interface{}{}) + + _ = strategyCache.ForceSync() + _, _, _ = strategyCache.realUpdate() + + policies := strategyCache.GetStrategyDetailsByUID("user-1") + assert.True(t, len(policies) > 0, len(policies)) + + policies = strategyCache.GetStrategyDetailsByGroupID("group-1") + assert.True(t, len(policies) > 0, len(policies)) + + policies = strategyCache.GetStrategyDetailsByUID("fake-user-1") + assert.True(t, len(policies) == 0, len(policies)) + + policies = strategyCache.GetStrategyDetailsByGroupID("fake-group-1") + assert.True(t, len(policies) == 0, len(policies)) + }) + t.Run("资源没有关联任何策略", func(t *testing.T) { ctrl := gomock.NewController(t) mockCacheMgr := cachemock.NewMockCacheManager(ctrl) @@ -45,12 +83,17 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { userCache := NewUserCache(mockStore, mockCacheMgr) strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) + mockStore.EXPECT().GetUnixSecond(gomock.Any()).Return(time.Now().Unix(), nil) + mockStore.EXPECT().GetStrategyDetailsForCache(gomock.Any(), gomock.Any()).Return(buildStrategies(10), nil).AnyTimes() mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() userCache.Initialize(map[string]interface{}{}) strategyCache.Initialize(map[string]interface{}{}) - strategyCache.setStrategys(buildStrategies(10)) + _ = strategyCache.ForceSync() + _, _, _ = strategyCache.realUpdate() ret := strategyCache.IsResourceEditable(model.Principal{ PrincipalID: "user-1", @@ -58,6 +101,15 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { }, apisecurity.ResourceType_Namespaces, "namespace-1") assert.True(t, ret, "must be true") + + ret = strategyCache.IsResourceLinkStrategy(apisecurity.ResourceType_Namespaces, "namespace-1") + assert.True(t, ret, "must be true") + ret = strategyCache.IsResourceLinkStrategy(apisecurity.ResourceType_Services, "service-1") + assert.True(t, ret, "must be true") + ret = strategyCache.IsResourceLinkStrategy(apisecurity.ResourceType_ConfigGroups, "config_group-1") + assert.True(t, ret, "must be true") + + strategyCache.Clear() }) t.Run("操作的目标资源关联了策略-自己在principal-user列表中", func(t *testing.T) { @@ -73,6 +125,8 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() userCache.Initialize(map[string]interface{}{}) strategyCache.Initialize(map[string]interface{}{}) @@ -119,6 +173,8 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() userCache.Initialize(map[string]interface{}{}) strategyCache.Initialize(map[string]interface{}{}) @@ -129,7 +185,18 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { PrincipalID: "user-20", PrincipalRole: model.PrincipalUser, }, apisecurity.ResourceType_Namespaces, "namespace-1") + assert.False(t, ret, "must be false") + ret = strategyCache.IsResourceEditable(model.Principal{ + PrincipalID: "user-20", + PrincipalRole: model.PrincipalUser, + }, apisecurity.ResourceType_Services, "service-1") + assert.False(t, ret, "must be false") + + ret = strategyCache.IsResourceEditable(model.Principal{ + PrincipalID: "user-20", + PrincipalRole: model.PrincipalUser, + }, apisecurity.ResourceType_ConfigGroups, "config_group-1") assert.False(t, ret, "must be false") }) @@ -146,6 +213,8 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() userCache.Initialize(map[string]interface{}{}) strategyCache.Initialize(map[string]interface{}{}) @@ -186,6 +255,8 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() userCache.Initialize(map[string]interface{}{}) strategyCache.Initialize(map[string]interface{}{}) @@ -276,6 +347,8 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { strategyCache := NewStrategyCache(mockStore, mockCacheMgr).(*strategyCache) mockCacheMgr.EXPECT().GetCacher(types.CacheUser).Return(userCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() userCache.Initialize(map[string]interface{}{}) strategyCache.Initialize(map[string]interface{}{}) @@ -373,6 +446,9 @@ func Test_strategyCache_IsResourceEditable_1(t *testing.T) { assert.True(t, ret, "must be true") + ret = strategyCache.IsResourceLinkStrategy(apisecurity.ResourceType_Namespaces, "namespace-1") + assert.True(t, ret, "must be true") + strategyDetail.Valid = false strategyCache.handlerPrincipalStrategy([]*model.StrategyDetail{strategyDetail}) @@ -419,6 +495,11 @@ func buildStrategies(num int) []*model.StrategyDetail { ResType: 1, ResID: fmt.Sprintf("service-%d", i+1), }, + { + StrategyID: fmt.Sprintf("rule-%d", i+1), + ResType: 2, + ResID: fmt.Sprintf("config_group-%d", i+1), + }, }, }) } diff --git a/cache/auth/user.go b/cache/auth/user.go index c56462009..673724c3d 100644 --- a/cache/auth/user.go +++ b/cache/auth/user.go @@ -111,20 +111,13 @@ func (uc *userCache) realUpdate() (map[string]time.Time, int64, error) { } lastMimes, refreshRet := uc.setUserAndGroups(users, groups) - timeDiff := time.Since(start) - if timeDiff > time.Second { - log.Info("[Cache][User] get more user", - zap.Int("add", refreshRet.userAdd), - zap.Int("update", refreshRet.userUpdate), - zap.Int("delete", refreshRet.userDel), - zap.Time("last", time.Unix(uc.lastUserMtime, 0)), zap.Duration("used", time.Since(start))) - - log.Info("[Cache][Group] get more group", - zap.Int("add", refreshRet.groupAdd), - zap.Int("update", refreshRet.groupUpdate), - zap.Int("delete", refreshRet.groupDel), - zap.Time("last", time.Unix(uc.lastGroupMtime, 0)), zap.Duration("used", time.Since(start))) - } + log.Info("[Cache][User] get more user and user_group", + zap.Int("user_add", refreshRet.userAdd), zap.Int("user_update", refreshRet.userUpdate), + zap.Int("user_delete", refreshRet.userDel), zap.Time("user_modify_last", time.Unix(uc.lastUserMtime, 0)), + zap.Int("group_add", refreshRet.groupAdd), zap.Int("group_update", refreshRet.groupUpdate), + zap.Int("group_delete", refreshRet.groupDel), zap.Time("group_modify_last", time.Unix(uc.lastGroupMtime, 0)), + zap.Duration("used", time.Since(start))) + return lastMimes, int64(len(users) + len(groups)), nil } diff --git a/cache/auth/user_test.go b/cache/auth/user_test.go index 1ad71baa4..8d7330023 100644 --- a/cache/auth/user_test.go +++ b/cache/auth/user_test.go @@ -124,6 +124,12 @@ func TestUserCache_UpdateNormal(t *testing.T) { users := genModelUsers(10) groups := genModelUserGroups(users) + admin := &model.User{ + ID: "admin-polaris", + Name: "admin-polaris", + Type: model.AdminUserRole, + Valid: true, + } t.Run("首次更新用户", func(t *testing.T) { copyUsers := make([]*model.User, 0, len(users)) @@ -133,6 +139,7 @@ func TestUserCache_UpdateNormal(t *testing.T) { copyUser := *users[i] copyUsers = append(copyUsers, ©User) } + copyUsers = append(copyUsers, admin) for i := range groups { copyGroup := *groups[i] @@ -165,6 +172,16 @@ func TestUserCache_UpdateNormal(t *testing.T) { assert.Equal(t, groups[1].ID, gid[0]) }) + t.Run("Is_owner", func(t *testing.T) { + assert.True(t, uc.IsOwner(users[0].ID), users[0].Type) + assert.False(t, uc.IsOwner(users[1].ID), users[1].Type) + assert.False(t, uc.IsOwner("fake-user-12312313")) + }) + + t.Run("Get_Admin", func(t *testing.T) { + assert.NotNil(t, uc.GetAdmin()) + }) + t.Run("部分用户删除", func(t *testing.T) { deleteCnt := 0 @@ -247,4 +264,17 @@ func TestUserCache_UpdateNormal(t *testing.T) { }) + t.Run("Abnormal_scene", func(t *testing.T) { + t.Run("group_id_empty", func(t *testing.T) { + assert.Nil(t, uc.GetGroup("")) + }) + + t.Run("user_id_empty", func(t *testing.T) { + assert.False(t, uc.IsUserInGroup("", "")) + assert.Nil(t, uc.GetUserByID("")) + assert.Nil(t, uc.GetUserLinkGroupIds("")) + }) + }) + + uc.Clear() } diff --git a/cache/cache.go b/cache/cache.go index d1617738e..f183fc669 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -37,6 +37,10 @@ const ( UpdateCacheInterval = 1 * time.Second ) +var ( + ReportInterval = 1 * time.Second +) + // CacheManager 名字服务缓存 type CacheManager struct { storage store.Store @@ -165,6 +169,11 @@ func (nc *CacheManager) GetUpdateCacheInterval() time.Duration { return UpdateCacheInterval } +// GetReportInterval 获取当前cache的更新间隔 +func (nc *CacheManager) GetReportInterval() time.Duration { + return ReportInterval +} + // Service 获取Service缓存信息 func (nc *CacheManager) Service() types.ServiceCache { return nc.caches[types.CacheService].(types.ServiceCache) diff --git a/cache/config.go b/cache/config.go index 11c30c64e..989987d38 100644 --- a/cache/config.go +++ b/cache/config.go @@ -23,6 +23,8 @@ import "time" type Config struct { // DiffTime 设置拉取时间范围, [T1 - abs(DiffTime), T1] DiffTime time.Duration `yaml:"diffTime"` + // ReportInterval 监控数据上报周期 + ReportInterval time.Duration `yaml:"reportInterval"` } var ( diff --git a/cache/config/config_file.go b/cache/config/config_file.go index de0d7fc86..b6ff0bdde 100644 --- a/cache/config/config_file.go +++ b/cache/config/config_file.go @@ -65,11 +65,11 @@ type fileCache struct { // NewConfigFileCache 创建文件缓存 func NewConfigFileCache(storage store.Store, cacheMgr types.CacheManager) types.ConfigFileCache { - cache := &fileCache{ - BaseCache: types.NewBaseCache(storage, cacheMgr), + fc := &fileCache{ storage: storage, } - return cache + fc.BaseCache = types.NewBaseCacheWithRepoerMetrics(storage, cacheMgr, fc.reportMetricsInfo) + return fc } // Initialize diff --git a/cache/config/config_file_metrics.go b/cache/config/config_file_metrics.go index a23ef8b23..3b938c3cd 100644 --- a/cache/config/config_file_metrics.go +++ b/cache/config/config_file_metrics.go @@ -53,7 +53,7 @@ func (fc *fileCache) reportMetricsInfo() { tmpGroup[ns][group] = struct{}{} } } - cleanExpireConfigFileMetricLabel(fc.preMetricsFiles.Load(), tmpGroup) + _, _ = cleanExpireConfigFileMetricLabel(fc.preMetricsFiles.Load(), tmpGroup) fc.preMetricsFiles.Store(tmpGroup) for ns, groups := range configFiles { @@ -85,26 +85,26 @@ func (fc *fileCache) reportMetricsInfo() { plugin.GetStatis().ReportConfigMetrics(metricValues...) } -func cleanExpireConfigFileMetricLabel(pre, curr map[string]map[string]struct{}) { +func cleanExpireConfigFileMetricLabel(pre, curr map[string]map[string]struct{}) (map[string]struct{}, map[string]map[string]struct{}) { if len(pre) == 0 { - return + return map[string]struct{}{}, map[string]map[string]struct{}{} } var ( - removeNs = map[string]struct{}{} - remove = map[string]map[string]struct{}{} + removeNs = map[string]struct{}{} + removeGroups = map[string]map[string]struct{}{} ) for ns, groups := range pre { if _, ok := curr[ns]; !ok { removeNs[ns] = struct{}{} } - if _, ok := remove[ns]; !ok { - remove[ns] = map[string]struct{}{} + if _, ok := removeGroups[ns]; !ok { + removeGroups[ns] = map[string]struct{}{} } for group := range groups { if _, ok := curr[ns][group]; !ok { - remove[ns][group] = struct{}{} + removeGroups[ns][group] = struct{}{} } } } @@ -115,7 +115,7 @@ func cleanExpireConfigFileMetricLabel(pre, curr map[string]map[string]struct{}) }) } - for ns, groups := range remove { + for ns, groups := range removeGroups { for group := range groups { metrics.GetConfigFileTotal().Delete(prometheus.Labels{ metrics.LabelNamespace: ns, @@ -131,5 +131,5 @@ func cleanExpireConfigFileMetricLabel(pre, curr map[string]map[string]struct{}) }) } } - + return removeNs, removeGroups } diff --git a/cache/config/config_file_metrics_test.go b/cache/config/config_file_metrics_test.go new file mode 100644 index 000000000..e7470ffe7 --- /dev/null +++ b/cache/config/config_file_metrics_test.go @@ -0,0 +1,75 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package config + +import ( + "reflect" + "testing" + + "github.com/polarismesh/polaris/common/metrics" +) + +func Test_cleanExpireConfigFileMetricLabel(t *testing.T) { + metrics.InitMetrics() + type args struct { + pre map[string]map[string]struct{} + curr map[string]map[string]struct{} + } + tests := []struct { + name string + args args + want map[string]struct{} + want1 map[string]map[string]struct{} + }{ + { + name: "01", + args: args{ + pre: map[string]map[string]struct{}{ + "ns-1": { + "group-1": {}, + }, + }, + curr: map[string]map[string]struct{}{ + "ns-2": { + "group-2": {}, + }, + }, + }, + want: map[string]struct{}{ + "ns-1": {}, + }, + want1: map[string]map[string]struct{}{ + "ns-1": { + "group-1": {}, + }, + }, + }, + + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := cleanExpireConfigFileMetricLabel(tt.args.pre, tt.args.curr) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("cleanExpireConfigFileMetricLabel() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(got1, tt.want1) { + t.Errorf("cleanExpireConfigFileMetricLabel() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} diff --git a/cache/config/config_group.go b/cache/config/config_group.go index 22f0d28c8..04db33c4c 100644 --- a/cache/config/config_group.go +++ b/cache/config/config_group.go @@ -46,11 +46,11 @@ type configGroupCache struct { // NewConfigGroupCache 创建文件缓存 func NewConfigGroupCache(storage store.Store, cacheMgr types.CacheManager) types.ConfigGroupCache { - cache := &configGroupCache{ - BaseCache: types.NewBaseCache(storage, cacheMgr), + gc := &configGroupCache{ storage: storage, } - return cache + gc.BaseCache = types.NewBaseCacheWithRepoerMetrics(storage, cacheMgr, gc.reportMetricsInfo) + return gc } // Initialize @@ -145,8 +145,6 @@ func (fc *configGroupCache) postProcessUpdatedGroups(affect map[string]struct{}) continue } count := nsBucket.Len() - fc.reportMetricsInfo(ns, count) - revisions := make([]string, 0, count) nsBucket.Range(func(key string, val *model.ConfigFileGroup) { revisions = append(revisions, val.Revision) diff --git a/cache/config/config_group_metrics.go b/cache/config/config_group_metrics.go index c669c31f7..ea2f2624a 100644 --- a/cache/config/config_group_metrics.go +++ b/cache/config/config_group_metrics.go @@ -19,17 +19,23 @@ package config import ( "github.com/polarismesh/polaris/common/metrics" + "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/plugin" ) -func (fc *configGroupCache) reportMetricsInfo(ns string, count int) { - reportValue := metrics.ConfigMetrics{ - Type: metrics.ConfigGroupMetric, - Total: int64(count), - Release: 0, - Labels: map[string]string{ - metrics.LabelNamespace: ns, - }, - } - plugin.GetStatis().ReportConfigMetrics(reportValue) +func (fc *configGroupCache) reportMetricsInfo() { + fc.name2groups.Range(func(ns string, val *utils.SyncMap[string, *model.ConfigFileGroup]) { + count := val.Len() + reportValue := metrics.ConfigMetrics{ + Type: metrics.ConfigGroupMetric, + Total: int64(count), + Release: 0, + Labels: map[string]string{ + metrics.LabelNamespace: ns, + }, + } + plugin.GetStatis().ReportConfigMetrics(reportValue) + }) + } diff --git a/cache/mock/cache_mock.go b/cache/mock/cache_mock.go index 81baae2b1..ab05d026c 100644 --- a/cache/mock/cache_mock.go +++ b/cache/mock/cache_mock.go @@ -7,6 +7,7 @@ package mock import ( context "context" reflect "reflect" + time "time" gomock "github.com/golang/mock/gomock" api "github.com/polarismesh/polaris/cache/api" @@ -246,6 +247,34 @@ func (mr *MockCacheManagerMockRecorder) GetCacher(cacheIndex interface{}) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCacher", reflect.TypeOf((*MockCacheManager)(nil).GetCacher), cacheIndex) } +// GetReportInterval mocks base method. +func (m *MockCacheManager) GetReportInterval() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetReportInterval") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// GetReportInterval indicates an expected call of GetReportInterval. +func (mr *MockCacheManagerMockRecorder) GetReportInterval() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetReportInterval", reflect.TypeOf((*MockCacheManager)(nil).GetReportInterval)) +} + +// GetUpdateCacheInterval mocks base method. +func (m *MockCacheManager) GetUpdateCacheInterval() time.Duration { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUpdateCacheInterval") + ret0, _ := ret[0].(time.Duration) + return ret0 +} + +// GetUpdateCacheInterval indicates an expected call of GetUpdateCacheInterval. +func (mr *MockCacheManagerMockRecorder) GetUpdateCacheInterval() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUpdateCacheInterval", reflect.TypeOf((*MockCacheManager)(nil).GetUpdateCacheInterval)) +} + // Gray mocks base method. func (m *MockCacheManager) Gray() api.GrayCache { m.ctrl.T.Helper() @@ -274,6 +303,20 @@ func (mr *MockCacheManagerMockRecorder) Instance() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Instance", reflect.TypeOf((*MockCacheManager)(nil).Instance)) } +// LaneRule mocks base method. +func (m *MockCacheManager) LaneRule() api.LaneCache { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LaneRule") + ret0, _ := ret[0].(api.LaneCache) + return ret0 +} + +// LaneRule indicates an expected call of LaneRule. +func (mr *MockCacheManagerMockRecorder) LaneRule() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LaneRule", reflect.TypeOf((*MockCacheManager)(nil).LaneRule)) +} + // Namespace mocks base method. func (m *MockCacheManager) Namespace() api.NamespaceCache { m.ctrl.T.Helper() diff --git a/cache/service/instance.go b/cache/service/instance.go index 491b833b3..80c60a5f0 100644 --- a/cache/service/instance.go +++ b/cache/service/instance.go @@ -63,11 +63,13 @@ type instanceCache struct { // NewInstanceCache 新建一个instanceCache func NewInstanceCache(storage store.Store, cacheMgr types.CacheManager) types.InstanceCache { - return &instanceCache{ - BaseCache: types.NewBaseCache(storage, cacheMgr), + ic := &instanceCache{ storage: storage, singleFlight: new(singleflight.Group), } + + ic.BaseCache = types.NewBaseCacheWithRepoerMetrics(storage, cacheMgr, ic.reportMetricsInfo) + return ic } // Initialize 初始化函数 @@ -161,7 +163,6 @@ func (ic *instanceCache) realUpdate() (map[string]time.Time, int64, error) { for i := range instanceChangeEvents { _ = eventhub.Publish(eventhub.CacheInstanceEventTopic, instanceChangeEvents[i]) } - ic.reportMetricsInfo() }() if err := tx.CreateReadView(); err != nil { diff --git a/cache/service/instance_test.go b/cache/service/instance_test.go index e1522fcf8..0b6ecf8b5 100644 --- a/cache/service/instance_test.go +++ b/cache/service/instance_test.go @@ -48,6 +48,8 @@ func newTestInstanceCache(t *testing.T) (*gomock.Controller, *mock.MockStore, *i mockCacheMgr.EXPECT().GetCacher(types.CacheService).Return(mockSvcCache).AnyTimes() mockCacheMgr.EXPECT().GetCacher(types.CacheInstance).Return(mockInstCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() mockTx := mock.NewMockTx(ctl) mockTx.EXPECT().Commit().Return(nil).AnyTimes() diff --git a/cache/service/ratelimit_config_test.go b/cache/service/ratelimit_config_test.go index 698e6e22c..bc0b62a9a 100644 --- a/cache/service/ratelimit_config_test.go +++ b/cache/service/ratelimit_config_test.go @@ -51,6 +51,8 @@ func newTestRateLimitCache(t *testing.T) (*gomock.Controller, *mock.MockStore, * mockCacheMgr.EXPECT().GetCacher(types.CacheService).Return(mockSvcCache).AnyTimes() mockCacheMgr.EXPECT().GetCacher(types.CacheInstance).Return(mockInstCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) var opt map[string]interface{} diff --git a/cache/service/router_rule.go b/cache/service/router_rule.go index 1953674a5..2d52d62f9 100644 --- a/cache/service/router_rule.go +++ b/cache/service/router_rule.go @@ -117,7 +117,7 @@ func (rc *routingConfigCache) Name() string { } func (rc *routingConfigCache) ListRouterRule(service, namespace string) []*model.ExtendRouterConfig { - routerRules := rc.bucket.listEnableRules(service, namespace) + routerRules := rc.bucket.listEnableRules(service, namespace, true) ret := make([]*model.ExtendRouterConfig, 0, len(routerRules)) for level := range routerRules { items := routerRules[level] @@ -132,7 +132,7 @@ func (rc *routingConfigCache) GetRouterConfigV2(id, service, namespace string) ( return nil, nil } - routerRules := rc.bucket.listEnableRules(service, namespace) + routerRules := rc.bucket.listEnableRules(service, namespace, true) revisions := make([]string, 0, 8) rulesV2 := make([]*apitraffic.RouteRule, 0, len(routerRules)) for level := range routerRules { @@ -167,7 +167,7 @@ func (rc *routingConfigCache) GetRouterConfig(id, service, namespace string) (*a return nil, nil } - routerRules := rc.bucket.listEnableRules(service, namespace) + routerRules := rc.bucket.listEnableRules(service, namespace, false) inBounds, outBounds, revisions := rc.convertV2toV1(routerRules, service, namespace) revision, err := types.CompositeComputeRevision(revisions) if err != nil { diff --git a/cache/service/router_rule_bucket.go b/cache/service/router_rule_bucket.go index c0d9b49fa..7f5f19376 100644 --- a/cache/service/router_rule_bucket.go +++ b/cache/service/router_rule_bucket.go @@ -288,7 +288,7 @@ func (b *routeRuleBucket) size() int { // listEnableRules Inquire the routing rules of the V2 version through the service name, // and perform some filtering according to the Predicate -func (b *routeRuleBucket) listEnableRules(service, namespace string) map[routingLevel][]*model.ExtendRouterConfig { +func (b *routeRuleBucket) listEnableRules(service, namespace string, enableFullMatch bool) map[routingLevel][]*model.ExtendRouterConfig { ret := make(map[routingLevel][]*model.ExtendRouterConfig) tmpRecord := map[string]struct{}{} @@ -338,11 +338,13 @@ func (b *routeRuleBucket) listEnableRules(service, namespace string) map[routing level2 = append(level2, handler(b.level2Rules[inBound][namespace], inBound)...) ret[level2RoutingV2] = level2 - // Query Level3 level routing-v2 rules - level3 := make([]*model.ExtendRouterConfig, 0, 4) - level3 = append(level3, handler(b.level3Rules[outBound], outBound)...) - level3 = append(level3, handler(b.level3Rules[inBound], inBound)...) - ret[level3RoutingV2] = level3 + if enableFullMatch { + // Query Level3 level routing-v2 rules + level3 := make([]*model.ExtendRouterConfig, 0, 4) + level3 = append(level3, handler(b.level3Rules[outBound], outBound)...) + level3 = append(level3, handler(b.level3Rules[inBound], inBound)...) + ret[level3RoutingV2] = level3 + } return ret } diff --git a/cache/service/service_test.go b/cache/service/service_test.go index c7cfd8b37..087c8555a 100644 --- a/cache/service/service_test.go +++ b/cache/service/service_test.go @@ -50,6 +50,8 @@ func newTestServiceCache(t *testing.T) (*gomock.Controller, *mock.MockStore, *se mockCacheMgr.EXPECT().GetCacher(types.CacheService).Return(mockSvcCache).AnyTimes() mockCacheMgr.EXPECT().GetCacher(types.CacheInstance).Return(mockInstCache).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() mockTx := mock.NewMockTx(ctl) mockTx.EXPECT().Commit().Return(nil).AnyTimes() @@ -485,6 +487,9 @@ func TestRevisionWorker(t *testing.T) { ctl := gomock.NewController(t) storage := mock.NewMockStore(ctl) mockCacheMgr := cachemock.NewMockCacheManager(ctl) + + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() storage.EXPECT().GetUnixSecond(gomock.Any()).AnyTimes().Return(time.Now().Unix(), nil) defer ctl.Finish() @@ -600,6 +605,8 @@ func Test_serviceCache_GetVisibleServicesInOtherNamespace(t *testing.T) { ctl := gomock.NewController(t) storage := mock.NewMockStore(ctl) mockCacheMgr := cachemock.NewMockCacheManager(ctl) + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() defer ctl.Finish() t.Run("服务可见性查询判断", func(t *testing.T) { diff --git a/common/conn/limit/config_test.go b/common/conn/limit/config_test.go index 12a050401..490355f8c 100644 --- a/common/conn/limit/config_test.go +++ b/common/conn/limit/config_test.go @@ -21,12 +21,12 @@ import ( "testing" "time" - . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" ) // TestParseConnLimitConfig 可以正常解析配置测试 func TestParseConnLimitConfig(t *testing.T) { - Convey("可以正常解析配置", t, func() { + t.Run("可以正常解析配置", func(t *testing.T) { options := map[interface{}]interface{}{ "openConnLimit": true, "maxConnPerHost": 16, @@ -35,11 +35,11 @@ func TestParseConnLimitConfig(t *testing.T) { "readTimeout": "120s", } config, err := ParseConnLimitConfig(options) - So(err, ShouldBeNil) - So(config.OpenConnLimit, ShouldBeTrue) - So(config.MaxConnPerHost, ShouldEqual, 16) - So(config.MaxConnLimit, ShouldEqual, 128) - So(config.WhiteList, ShouldEqual, "127.0.0.1,127.0.0.2,127.0.0.3") - So(config.ReadTimeout, ShouldEqual, time.Second*120) + assert.Nil(t, err) + assert.True(t, config.OpenConnLimit) + assert.Equal(t, config.MaxConnPerHost, 16) + assert.Equal(t, config.MaxConnLimit, 128) + assert.Equal(t, config.WhiteList, "127.0.0.1,127.0.0.2,127.0.0.3") + assert.Equal(t, config.ReadTimeout, time.Second*120) }) } diff --git a/common/utils/common.go b/common/utils/common.go index 43bac20b2..fcf79b6e6 100644 --- a/common/utils/common.go +++ b/common/utils/common.go @@ -86,6 +86,11 @@ const ( MaxDbCircuitbreakerComment = 1024 MaxDbCircuitbreakerOwner = 1024 MaxDbCircuitbreakerVersion = 32 + + MaxPlatformIDLength = 32 + MaxPlatformNameLength = 128 + MaxPlatformDomainLength = 1024 + MaxPlatformQPS = 65535 ) var resourceNameRE = regexp.MustCompile("^[0-9A-Za-z-./:_]+$") diff --git a/config/client.go b/config/client.go index f960b45de..5636b5cb2 100644 --- a/config/client.go +++ b/config/client.go @@ -92,11 +92,11 @@ func (s *Server) LongPullWatchFile(ctx context.Context, req *apiconfig.ClientWatchConfigFileRequest) (WatchCallback, error) { watchFiles := req.GetWatchFiles() - tmpWatchCtx := BuildTimeoutWatchCtx(ctx, 0)("", s.watchCenter.MatchBetaReleaseFile) + tmpWatchCtx := BuildTimeoutWatchCtx(ctx, req, 0)("", s.watchCenter.MatchBetaReleaseFile) for _, file := range watchFiles { tmpWatchCtx.AppendInterest(file) } - if quickResp := s.watchCenter.checkQuickResponseClient(tmpWatchCtx); quickResp != nil { + if quickResp := s.watchCenter.CheckQuickResponseClient(tmpWatchCtx); quickResp != nil { _ = tmpWatchCtx.Close() return func() *apiconfig.ConfigClientResponse { return quickResp @@ -110,16 +110,20 @@ func (s *Server) LongPullWatchFile(ctx context.Context, // 3. 监听配置变更,hold 请求 30s,30s 内如果有配置发布,则响应请求 clientId := utils.ParseClientAddress(ctx) + "@" + utils.NewUUID()[0:8] - watchCtx := s.WatchCenter().AddWatcher(clientId, watchFiles, BuildTimeoutWatchCtx(ctx, watchTimeOut)) + watchCtx := s.WatchCenter().AddWatcher(clientId, watchFiles, BuildTimeoutWatchCtx(ctx, req, watchTimeOut)) return func() *apiconfig.ConfigClientResponse { return (watchCtx.(*LongPollWatchContext)).GetNotifieResult() }, nil } -func BuildTimeoutWatchCtx(ctx context.Context, watchTimeOut time.Duration) WatchContextFactory { +func BuildTimeoutWatchCtx(ctx context.Context, req *apiconfig.ClientWatchConfigFileRequest, + watchTimeOut time.Duration) WatchContextFactory { labels := map[string]string{ model.ClientLabel_IP: utils.ParseClientIP(ctx), } + if len(req.GetClientIp().GetValue()) != 0 { + labels[model.ClientLabel_IP] = req.GetClientIp().GetValue() + } return func(clientId string, matcher BetaReleaseMatcher) WatchContext { watchCtx := &LongPollWatchContext{ clientId: clientId, diff --git a/config/client_test.go b/config/client_test.go index 37df0933e..c95c99b48 100644 --- a/config/client_test.go +++ b/config/client_test.go @@ -367,7 +367,104 @@ func TestWatchConfigFileAtFirstPublish(t *testing.T) { // 创建并发布配置文件 configFile := assembleConfigFile() - t.Run("第一次订阅发布", func(t *testing.T) { + t.Run("00_QuickCheck", func(t *testing.T) { + curConfigFile := assembleConfigFile() + curConfigFile.Namespace = utils.NewStringValue("QuickCheck") + + rsp := testSuit.ConfigServer().CreateConfigFile(testSuit.DefaultCtx, curConfigFile) + t.Log("create config file success") + assert.Equal(t, api.ExecuteSuccess, rsp.Code.GetValue(), rsp.GetInfo().GetValue()) + + watchCenter := testSuit.OriginConfigServer().WatchCenter() + + t.Run("01_file_not_exist", func(t *testing.T) { + _ = testSuit.DiscoverServer().Cache().ConfigFile().Update() + + tmpWatchCtx := config.BuildTimeoutWatchCtx(context.Background(), + &apiconfig.ClientWatchConfigFileRequest{ + WatchFiles: []*apiconfig.ClientConfigFileInfo{ + { + Namespace: curConfigFile.Namespace, + Group: curConfigFile.Group, + FileName: curConfigFile.Name, + }, + }, + }, 0)(utils.NewUUID(), watchCenter.MatchBetaReleaseFile) + + rsp := watchCenter.CheckQuickResponseClient(tmpWatchCtx) + assert.Nil(t, rsp, rsp) + }) + + t.Run("02_normal", func(t *testing.T) { + // 发布一个正常的配置文件 + rsp2 := testSuit.ConfigServer().PublishConfigFile(testSuit.DefaultCtx, assembleConfigFileRelease(curConfigFile)) + t.Log("publish config file success") + assert.Equal(t, api.ExecuteSuccess, rsp2.Code.GetValue(), rsp2.GetInfo().GetValue()) + + _ = testSuit.DiscoverServer().Cache().ConfigFile().Update() + + req := &apiconfig.ClientWatchConfigFileRequest{ + WatchFiles: []*apiconfig.ClientConfigFileInfo{ + { + Namespace: curConfigFile.Namespace, + Group: curConfigFile.Group, + FileName: curConfigFile.Name, + }, + }, + } + tmpWatchCtx := config.BuildTimeoutWatchCtx(context.Background(), + req, 0)(utils.NewUUID(), watchCenter.MatchBetaReleaseFile) + + for i := range req.WatchFiles { + tmpWatchCtx.AppendInterest(req.WatchFiles[i]) + } + rsp := watchCenter.CheckQuickResponseClient(tmpWatchCtx) + assert.NotNil(t, rsp, rsp) + assert.True(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + }) + + t.Run("03_gray", func(t *testing.T) { + // 发布一个灰度配置文件 + curConfigFile.Content = utils.NewStringValue("gray polaris test") + grayRelease := assembleConfigFileRelease(curConfigFile) + grayRelease.ReleaseType = wrapperspb.String("gray") + grayRelease.BetaLabels = []*apimodel.ClientLabel{ + &apimodel.ClientLabel{ + Key: "CLIENT_IP", + Value: &apimodel.MatchString{ + Type: apimodel.MatchString_EXACT, + Value: utils.NewStringValue("172.0.0.1"), + ValueType: apimodel.MatchString_TEXT, + }, + }, + } + rsp2 := testSuit.ConfigServer().PublishConfigFile(testSuit.DefaultCtx, grayRelease) + t.Log("publish config file success") + assert.Equal(t, api.ExecuteSuccess, rsp2.Code.GetValue(), rsp2.GetInfo().GetValue()) + + req := &apiconfig.ClientWatchConfigFileRequest{ + ClientIp: wrapperspb.String("172.0.0.1"), + WatchFiles: []*apiconfig.ClientConfigFileInfo{ + { + Namespace: curConfigFile.Namespace, + Group: curConfigFile.Group, + FileName: curConfigFile.Name, + }, + }, + } + tmpWatchCtx := config.BuildTimeoutWatchCtx(context.Background(), + req, 0)(utils.NewUUID(), watchCenter.MatchBetaReleaseFile) + + for i := range req.WatchFiles { + tmpWatchCtx.AppendInterest(req.WatchFiles[i]) + } + rsp := watchCenter.CheckQuickResponseClient(tmpWatchCtx) + assert.NotNil(t, rsp, rsp) + assert.True(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + }) + }) + + t.Run("01_first_watch", func(t *testing.T) { watchConfigFiles := assembleDefaultClientConfigFile(0) clientId := "TestWatchConfigFileAtFirstPublish-first" @@ -376,7 +473,7 @@ func TestWatchConfigFileAtFirstPublish(t *testing.T) { }() watchCtx := testSuit.OriginConfigServer().WatchCenter().AddWatcher(clientId, watchConfigFiles, - config.BuildTimeoutWatchCtx(context.Background(), 30*time.Second)) + config.BuildTimeoutWatchCtx(context.Background(), &apiconfig.ClientWatchConfigFileRequest{}, 30*time.Second)) assert.NotNil(t, watchCtx) rsp := testSuit.ConfigServer().CreateConfigFile(testSuit.DefaultCtx, configFile) @@ -405,14 +502,14 @@ func TestWatchConfigFileAtFirstPublish(t *testing.T) { assert.Equal(t, uint64(1), receivedVersion) }) - t.Run("第二次订阅发布", func(t *testing.T) { + t.Run("02_second_watch", func(t *testing.T) { // 版本号由于发布过一次,所以是1 watchConfigFiles := assembleDefaultClientConfigFile(1) clientId := "TestWatchConfigFileAtFirstPublish-second" watchCtx := testSuit.OriginConfigServer().WatchCenter().AddWatcher(clientId, watchConfigFiles, - config.BuildTimeoutWatchCtx(context.Background(), 30*time.Second)) + config.BuildTimeoutWatchCtx(context.Background(), &apiconfig.ClientWatchConfigFileRequest{}, 30*time.Second)) assert.NotNil(t, watchCtx) rsp3 := testSuit.ConfigServer().PublishConfigFile(testSuit.DefaultCtx, assembleConfigFileRelease(configFile)) @@ -430,6 +527,24 @@ func TestWatchConfigFileAtFirstPublish(t *testing.T) { // 为了避免影响其它 case,删除订阅 testSuit.OriginConfigServer().WatchCenter().RemoveWatcher(clientId, watchConfigFiles) }) + + t.Run("03_clean_invalid_client", func(t *testing.T) { + watchConfigFiles := assembleDefaultClientConfigFile(1) + for i := range watchConfigFiles { + watchConfigFiles[i].Namespace = utils.NewStringValue("03_clean_invalid_client") + } + + // watchCtx 默认为 1s 超时 + watchCtx := testSuit.OriginConfigServer().WatchCenter().AddWatcher(utils.NewUUID(), watchConfigFiles, + config.BuildTimeoutWatchCtx(context.Background(), &apiconfig.ClientWatchConfigFileRequest{}, time.Second)) + assert.NotNil(t, watchCtx) + + time.Sleep(10 * time.Second) + + // + ret := watchCtx.(*config.LongPollWatchContext).GetNotifieResult() + assert.Equal(t, uint32(apimodel.Code_DataNoChange), ret.GetCode().GetValue()) + }) } // Test10000ClientWatchConfigFile 测试 10000 个客户端同时监听配置变更,配置发布所有客户端都收到通知 @@ -446,7 +561,7 @@ func TestManyClientWatchConfigFile(t *testing.T) { received.Store(clientId, false) receivedVersion.Store(clientId, uint64(0)) watchCtx := testSuit.OriginConfigServer().WatchCenter().AddWatcher(clientId, watchConfigFiles, - config.BuildTimeoutWatchCtx(context.Background(), 30*time.Second)) + config.BuildTimeoutWatchCtx(context.Background(), &apiconfig.ClientWatchConfigFileRequest{}, 30*time.Second)) assert.NotNil(t, watchCtx) go func() { notifyRsp := (watchCtx.(*config.LongPollWatchContext)).GetNotifieResult() @@ -523,10 +638,12 @@ func TestDeleteConfigFile(t *testing.T) { // 删除配置文件 t.Log("remove config file") - rsp3 := testSuit.ConfigServer().DeleteConfigFile(testSuit.DefaultCtx, &apiconfig.ConfigFile{ - Namespace: utils.NewStringValue(newMockNs), - Group: utils.NewStringValue(testGroup), - Name: utils.NewStringValue(testFile), + rsp3 := testSuit.ConfigServer().BatchDeleteConfigFile(testSuit.DefaultCtx, []*apiconfig.ConfigFile{ + &apiconfig.ConfigFile{ + Namespace: utils.NewStringValue(newMockNs), + Group: utils.NewStringValue(testGroup), + Name: utils.NewStringValue(testFile), + }, }) assert.Equal(t, api.ExecuteSuccess, rsp3.Code.GetValue()) _ = testSuit.CacheMgr().TestUpdate() @@ -573,6 +690,11 @@ func TestServer_GetConfigFileNamesWithCache(t *testing.T) { } }) + _ = testSuit.OriginConfigServer().CacheManager() + _ = testSuit.OriginConfigServer().GroupCache() + _ = testSuit.OriginConfigServer().FileCache() + _ = testSuit.OriginConfigServer().CryptoManager() + t.Run("bad-request", func(t *testing.T) { rsp := testSuit.ConfigServer().GetConfigFileNamesWithCache(testSuit.DefaultCtx, &apiconfig.ConfigFileGroupRequest{ ConfigFileGroup: &apiconfig.ConfigFileGroup{ diff --git a/config/config_file.go b/config/config_file.go index 3f91daa44..b4d4a8628 100644 --- a/config/config_file.go +++ b/config/config_file.go @@ -528,16 +528,3 @@ func configFileRecordEntry(ctx context.Context, req *apiconfig.ConfigFile, return entry } - -func checkReadFileParameter(req *apiconfig.ConfigFile) *apiconfig.ConfigResponse { - if req.GetNamespace().GetValue() == "" { - return api.NewConfigResponse(apimodel.Code_InvalidNamespaceName) - } - if req.GetGroup().GetValue() == "" { - return api.NewConfigResponse(apimodel.Code_InvalidConfigFileGroupName) - } - if req.GetName().GetValue() == "" { - return api.NewConfigResponse(apimodel.Code_InvalidConfigFileName) - } - return nil -} diff --git a/config/config_file_group.go b/config/config_file_group.go index 6077366d7..e81f57aae 100644 --- a/config/config_file_group.go +++ b/config/config_file_group.go @@ -100,7 +100,7 @@ func (s *Server) UpdateConfigFileGroup(ctx context.Context, req *apiconfig.Confi updateData := model.ToConfigGroupStore(req) updateData.ModifyBy = utils.ParseOperator(ctx) - updateData, needUpdate := s.updateGroupAttribute(saveData, updateData) + updateData, needUpdate := s.UpdateGroupAttribute(saveData, updateData) if !needUpdate { return api.NewConfigResponse(apimodel.Code_NoNeedUpdate) } @@ -122,7 +122,7 @@ func (s *Server) UpdateConfigFileGroup(ctx context.Context, req *apiconfig.Confi return api.NewConfigResponse(apimodel.Code_ExecuteSuccess) } -func (s *Server) updateGroupAttribute(saveData, updateData *model.ConfigFileGroup) (*model.ConfigFileGroup, bool) { +func (s *Server) UpdateGroupAttribute(saveData, updateData *model.ConfigFileGroup) (*model.ConfigFileGroup, bool) { needUpdate := false if saveData.Comment != updateData.Comment { needUpdate = true diff --git a/config/config_file_group_test.go b/config/config_file_group_test.go index 0ca05373a..5be79ee96 100644 --- a/config/config_file_group_test.go +++ b/config/config_file_group_test.go @@ -19,14 +19,16 @@ package config_test import ( "fmt" + "reflect" "testing" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/config" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/wrapperspb" - - api "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/utils" ) var ( @@ -220,3 +222,46 @@ func TestConfigFileGroupCRUD(t *testing.T) { assert.Equal(t, randomGroupSize, rsp2.Total.GetValue()) }) } + +func TestServer_UpdateGroupAttribute(t *testing.T) { + type args struct { + saveData *model.ConfigFileGroup + updateData *model.ConfigFileGroup + } + tests := []struct { + name string + args args + want *model.ConfigFileGroup + want1 bool + }{ + { + name: "01", + args: args{ + saveData: &model.ConfigFileGroup{ + Comment: "test", + Business: "test", + Department: "test", + }, + updateData: &model.ConfigFileGroup{ + Comment: "test-1", + }, + }, + want: &model.ConfigFileGroup{ + Comment: "test-1", + }, + want1: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s := &config.Server{} + got, got1 := s.UpdateGroupAttribute(tt.args.saveData, tt.args.updateData) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Server.UpdateGroupAttribute() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("Server.UpdateGroupAttribute() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} diff --git a/config/config_file_release_test.go b/config/config_file_release_test.go index 17eee06c4..34995cd79 100644 --- a/config/config_file_release_test.go +++ b/config/config_file_release_test.go @@ -226,6 +226,16 @@ func Test_PublishConfigFile(t *testing.T) { // 客户端读取数据正常 _ = testSuit.CacheMgr().TestUpdate() + cacheData := testSuit.CacheMgr().ConfigFile().GetRelease(model.ConfigFileReleaseKey{ + Namespace: mockNamespace + "same-v1", + Group: mockGroup + "same-v1", + FileName: mockFileName + "same-v1", + Name: mockReleaseName + "same-v1", + ReleaseType: model.ReleaseTypeFull, + }) + assert.NotNil(t, cacheData) + assert.Equal(t, mockContent+"same-v1", cacheData.Content) + clientRsp := testSuit.ConfigServer().GetConfigFileWithCache(testSuit.DefaultCtx, &config_manage.ClientConfigFileInfo{ Namespace: utils.NewStringValue(mockNamespace + "same-v1"), Group: utils.NewStringValue(mockGroup + "same-v1"), diff --git a/config/server.go b/config/server.go index 1f1db53d4..ee4476347 100644 --- a/config/server.go +++ b/config/server.go @@ -330,7 +330,7 @@ func (cc *ConfigChains) AfterGetFileHistory(ctx context.Context, func GetChainOrder() []string { return []string{ - "paramcheck", "auth", + "paramcheck", } } diff --git a/config/server_test.go b/config/server_test.go index b923ad656..ae1568c51 100644 --- a/config/server_test.go +++ b/config/server_test.go @@ -20,6 +20,7 @@ package config import ( "context" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -47,6 +48,8 @@ func Test_Initialize(t *testing.T) { cacheMgr.EXPECT().ConfigFile().Return(nil).AnyTimes() cacheMgr.EXPECT().Gray().Return(nil).AnyTimes() cacheMgr.EXPECT().ConfigGroup().Return(nil).AnyTimes() + cacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + cacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() _, _, err := auth.TestInitialize(context.Background(), &auth.Config{}, mockStore, cacheMgr) assert.NoError(t, err) diff --git a/config/utils_test.go b/config/utils_test.go index 4b85b1fb2..645efc966 100644 --- a/config/utils_test.go +++ b/config/utils_test.go @@ -29,3 +29,39 @@ func TestCheckFileName(t *testing.T) { err := CheckFileName(w) assert.Equal(t, err, nil) } + +func TestCheckContentLength(t *testing.T) { + type args struct { + content string + max int + } + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "01", + args: args{ + content: "123", + max: 10, + }, + wantErr: false, + }, + { + name: "02", + args: args{ + content: "134234123412312323", + max: 10, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := CheckContentLength(tt.args.content, tt.args.max); (err != nil) != tt.wantErr { + t.Errorf("CheckContentLength() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/config/watcher.go b/config/watcher.go index a86bfd92a..cc6b6b4fe 100644 --- a/config/watcher.go +++ b/config/watcher.go @@ -215,7 +215,7 @@ func (wc *watchCenter) OnEvent(ctx context.Context, arg any) error { return nil } -func (wc *watchCenter) checkQuickResponseClient(watchCtx WatchContext) *apiconfig.ConfigClientResponse { +func (wc *watchCenter) CheckQuickResponseClient(watchCtx WatchContext) *apiconfig.ConfigClientResponse { buildRet := func(release *model.ConfigFileRelease) *apiconfig.ConfigClientResponse { ret := &apiconfig.ClientConfigFileInfo{ Namespace: utils.NewStringValue(release.Namespace), diff --git a/plugin/healthchecker/leader/checker_leader.go b/plugin/healthchecker/leader/checker_leader.go index 227f2c03e..e65f6bc0c 100644 --- a/plugin/healthchecker/leader/checker_leader.go +++ b/plugin/healthchecker/leader/checker_leader.go @@ -477,33 +477,6 @@ func (c *LeaderHealthChecker) isLeader() bool { return atomic.LoadInt32(&c.leader) == 1 } -const ( - errCountThreshold = 2 - maxCheckCount = 3 -) - -func (c *LeaderHealthChecker) checkLeaderAlive(ctx context.Context) { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - peer := c.findLeaderPeer() - if peer == nil { - // 可能是在 Leader 调整中,不处理探测 - continue - } - - if !peer.IsAlive() { - log.Info("[Health Check][Leader] leader peer not alive, do suspend") - c.Suspend() - } - } - } -} - func (c *LeaderHealthChecker) DebugHandlers() []model.DebugHandler { return []model.DebugHandler{ { diff --git a/plugin/healthchecker/leader/debug_test.go b/plugin/healthchecker/leader/debug_test.go new file mode 100644 index 000000000..597bd0fa6 --- /dev/null +++ b/plugin/healthchecker/leader/debug_test.go @@ -0,0 +1,148 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package leader + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + commonhash "github.com/polarismesh/polaris/common/hash" + "github.com/polarismesh/polaris/common/utils" + "github.com/stretchr/testify/assert" +) + +func Test_LeaderCheckerDebugerHandler(t *testing.T) { + t.Run("handleDescribeLeaderInfo", func(t *testing.T) { + t.Run("01_to_early", func(t *testing.T) { + leader := &LeaderHealthChecker{} + httpFunc := handleDescribeLeaderInfo(leader) + recorder := httptest.NewRecorder() + httpFunc(recorder, httptest.NewRequest(http.MethodPost, "http://127.0.0.1:1234", nil)) + assert.Equal(t, http.StatusTooEarly, recorder.Code) + }) + + t.Run("02_self_leader", func(t *testing.T) { + leader := &LeaderHealthChecker{} + httpFunc := handleDescribeLeaderInfo(leader) + atomic.StoreInt32(&leader.leader, 1) + atomic.StoreInt32(&leader.initialize, 1) + leader.remote = &RemotePeer{ + host: "172.0.0.1", + } + + recorder := httptest.NewRecorder() + httpFunc(recorder, httptest.NewRequest(http.MethodPost, "http://127.0.0.1:1234", nil)) + assert.Equal(t, http.StatusOK, recorder.Code) + + ret := map[string]interface{}{} + data, _ := io.ReadAll(recorder.Body) + _ = json.Unmarshal(data, &ret) + + assert.Equal(t, utils.LocalHost, ret["leader"]) + assert.Equal(t, utils.LocalHost, ret["self"]) + }) + + t.Run("03_self_follower", func(t *testing.T) { + leader := &LeaderHealthChecker{} + httpFunc := handleDescribeLeaderInfo(leader) + atomic.StoreInt32(&leader.leader, 0) + atomic.StoreInt32(&leader.initialize, 1) + leader.remote = &RemotePeer{ + host: "172.0.0.1", + } + + recorder := httptest.NewRecorder() + httpFunc(recorder, httptest.NewRequest(http.MethodPost, "http://127.0.0.1:1234", nil)) + assert.Equal(t, http.StatusOK, recorder.Code) + + ret := map[string]interface{}{} + data, _ := io.ReadAll(recorder.Body) + _ = json.Unmarshal(data, &ret) + + assert.Equal(t, "172.0.0.1", ret["leader"]) + assert.Equal(t, utils.LocalHost, ret["self"]) + }) + }) + + t.Run("handleDescribeBeatCache", func(t *testing.T) { + t.Run("00_to_early", func(t *testing.T) { + leader := &LeaderHealthChecker{} + httpFunc := handleDescribeBeatCache(leader) + recorder := httptest.NewRecorder() + httpFunc(recorder, httptest.NewRequest(http.MethodPost, "http://127.0.0.1:1234", nil)) + assert.Equal(t, http.StatusTooEarly, recorder.Code) + }) + + t.Run("01_self_leader", func(t *testing.T) { + leader := &LeaderHealthChecker{ + self: &LocalPeer{ + Cache: newLocalBeatRecordCache(1, commonhash.Fnv32), + }, + } + leader.self.Storage().Put(WriteBeatRecord{ + Record: RecordValue{ + Server: utils.LocalHost, + CurTimeSec: 123, + Count: 0, + }, + Key: "123", + }) + + httpFunc := handleDescribeBeatCache(leader) + atomic.StoreInt32(&leader.leader, 1) + atomic.StoreInt32(&leader.initialize, 1) + + recorder := httptest.NewRecorder() + httpFunc(recorder, httptest.NewRequest(http.MethodPost, "http://127.0.0.1:1234", nil)) + assert.Equal(t, http.StatusOK, recorder.Code) + + ret := map[string]interface{}{} + data, _ := io.ReadAll(recorder.Body) + _ = json.Unmarshal(data, &ret) + + expectData, _ := json.Marshal(leader.self.Storage().Snapshot()) + expectRet := map[string]interface{}{} + _ = json.Unmarshal(expectData, &expectRet) + + assert.Equal(t, expectRet, ret["data"]) + assert.Equal(t, utils.LocalHost, ret["self"]) + }) + + t.Run("02_self_follower", func(t *testing.T) { + leader := &LeaderHealthChecker{} + httpFunc := handleDescribeBeatCache(leader) + atomic.StoreInt32(&leader.leader, 0) + atomic.StoreInt32(&leader.initialize, 1) + + recorder := httptest.NewRecorder() + httpFunc(recorder, httptest.NewRequest(http.MethodPost, "http://127.0.0.1:1234", nil)) + assert.Equal(t, http.StatusOK, recorder.Code) + + ret := map[string]interface{}{} + data, _ := io.ReadAll(recorder.Body) + _ = json.Unmarshal(data, &ret) + + assert.Equal(t, "Not Leader", ret["data"]) + assert.Equal(t, utils.LocalHost, ret["self"]) + }) + }) +} diff --git a/plugin/healthchecker/leader/peer.go b/plugin/healthchecker/leader/peer.go index ae05b2002..03f5b4b22 100644 --- a/plugin/healthchecker/leader/peer.go +++ b/plugin/healthchecker/leader/peer.go @@ -316,6 +316,11 @@ func (p *RemotePeer) choseOneSender() (*beatSender, error) { return p.puters[index], nil } +const ( + errCountThreshold = 2 + maxCheckCount = 3 +) + func (p *RemotePeer) checkLeaderAlive(ctx context.Context) { ticker := time.NewTicker(time.Second) for { @@ -409,6 +414,7 @@ func (p *RemotePeer) doReconnect(i int) bool { if err != nil { plog.Error("[HealthCheck][Leader] reconnect grpc-client", zap.String("host", p.Host()), zap.Uint32("port", p.port), zap.Error(err)) + _ = conn.Close() return false } @@ -491,6 +497,8 @@ func (s *beatSender) Recv() { if _, err := s.sender.Recv(); err != nil { plog.Error("[HealthCheck][Leader] receive put record result", zap.String("host", s.peer.Host()), zap.Uint32("port", s.peer.port), zap.Error(err)) + // 先关闭自己 + s.close() go s.peer.reconnect(s.index) return } diff --git a/release/standalone/docker/prometheus/build_docker_prom.sh b/release/standalone/docker/prometheus/build_docker_prom.sh index 39803ae45..89bcf77d6 100644 --- a/release/standalone/docker/prometheus/build_docker_prom.sh +++ b/release/standalone/docker/prometheus/build_docker_prom.sh @@ -7,6 +7,11 @@ fi docker_tag=$1 -echo "docker repository : polarismesh/polaris-prometheus, tag : ${docker_tag}" +docker_repository="${DOCKER_REPOSITORY}" +if [[ "${docker_repository}" == "" ]]; then + docker_repository="polarismesh" +fi + +echo "docker repository : ${docker_repository}/polaris-prometheus, tag : ${docker_tag}" -docker buildx build --network=host -t polarismesh/polaris-prometheus:${docker_tag} -t polarismesh/polaris-prometheus:latest --platform linux/amd64,linux/arm64 --push ./ +docker buildx build --network=host -t ${docker_repository}/polaris-prometheus:${docker_tag} -t ${docker_repository}/polaris-prometheus:latest --platform linux/amd64,linux/arm64 --push ./ diff --git a/service/api.go b/service/api.go index d6cc37ad1..b37f19f86 100644 --- a/service/api.go +++ b/service/api.go @@ -22,7 +22,7 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/common/model" ) @@ -39,7 +39,7 @@ type DiscoverServer interface { // ClientServer Client operation interface definition ClientServer // Cache Get cache management - Cache() *cache.CacheManager + Cache() cachetypes.CacheManager // L5OperateServer L5 related operations L5OperateServer // GetServiceInstanceRevision Get the version of the service diff --git a/service/batch/batch_test.go b/service/batch/batch_test.go index 81f3212f3..c699a9607 100644 --- a/service/batch/batch_test.go +++ b/service/batch/batch_test.go @@ -25,7 +25,6 @@ import ( "github.com/golang/mock/gomock" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" - . "github.com/smartystreets/goconvey/convey" "github.com/stretchr/testify/assert" "github.com/polarismesh/polaris/common/metrics" @@ -40,7 +39,7 @@ func init() { // TestNewBatchCtrlWithConfig 测试New func TestNewBatchCtrlWithConfig(t *testing.T) { - Convey("正常新建", t, func() { + t.Run("正常新建", func(t *testing.T) { ctrlConfig := &CtrlConfig{ Open: true, QueueSize: 1024, @@ -53,25 +52,25 @@ func TestNewBatchCtrlWithConfig(t *testing.T) { Deregister: ctrlConfig, } bc, err := NewBatchCtrlWithConfig(nil, nil, config) - So(err, ShouldBeNil) - So(bc, ShouldNotBeNil) - So(bc.register, ShouldNotBeNil) - So(bc.deregister, ShouldNotBeNil) + assert.Nil(t, err) + assert.NotNil(t, bc) + assert.NotNil(t, bc.register) + assert.NotNil(t, bc.deregister) }) - Convey("可以关闭register和deregister的batch操作", t, func() { + t.Run("可以关闭register和deregister的batch操作", func(t *testing.T) { bc, err := NewBatchCtrlWithConfig(nil, nil, nil) - So(err, ShouldBeNil) - So(bc, ShouldBeNil) + assert.Nil(t, err) + assert.Nil(t, bc) config := &Config{ Register: &CtrlConfig{Open: false}, Deregister: &CtrlConfig{Open: false}, } bc, err = NewBatchCtrlWithConfig(nil, nil, config) - So(err, ShouldBeNil) - So(bc, ShouldNotBeNil) - So(bc.register, ShouldBeNil) - So(bc.deregister, ShouldBeNil) + assert.Nil(t, err) + assert.NotNil(t, bc) + assert.Nil(t, bc.register) + assert.Nil(t, bc.deregister) }) } @@ -146,13 +145,22 @@ func TestAsyncCreateInstance(t *testing.T) { // TestSendReply 测试reply func TestSendReply(t *testing.T) { - Convey("可以正常获取类型", t, func() { + t.Run("可以正常获取类型", func(t *testing.T) { sendReply(make([]*InstanceFuture, 0, 10), 1, nil) }) - Convey("可以正常获取类型2", t, func() { + t.Run("可以正常获取类型2", func(t *testing.T) { sendReply(make(map[string]*InstanceFuture, 10), 1, nil) }) - Convey("其他类型不通过", t, func() { + t.Run("其他类型不通过", func(t *testing.T) { sendReply("test string", 1, nil) }) + t.Run("可以正常获取类型", func(t *testing.T) { + SendClientReply(make([]*ClientFuture, 0, 10), 1, nil) + }) + t.Run("可以正常获取类型2", func(t *testing.T) { + SendClientReply(make(map[string]*ClientFuture, 10), 1, nil) + }) + t.Run("其他类型不通过", func(t *testing.T) { + SendClientReply("test string", 1, nil) + }) } diff --git a/service/batch/client_future_test.go b/service/batch/client_future_test.go new file mode 100644 index 000000000..4b569c9f7 --- /dev/null +++ b/service/batch/client_future_test.go @@ -0,0 +1,32 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package batch + +import ( + "testing" + + "github.com/polarismesh/polaris/common/model" + "github.com/stretchr/testify/assert" +) + +func TestClientFuture_SetClient(t *testing.T) { + f := &ClientFuture{} + + f.SetClient(&model.Client{}) + assert.NotNil(t, f.Client()) +} diff --git a/service/batch/config_test.go b/service/batch/config_test.go new file mode 100644 index 000000000..ace25d45c --- /dev/null +++ b/service/batch/config_test.go @@ -0,0 +1,144 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package batch + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseBatchConfig(t *testing.T) { + type args struct { + opt map[string]interface{} + } + tests := []struct { + name string + args args + want *Config + wantErr bool + }{ + { + name: "opt_nil", + args: args{ + opt: nil, + }, + want: nil, + wantErr: false, + }, + { + name: "register", + args: args{ + opt: map[string]interface{}{ + "register": map[string]interface{}{ + "queueSize": 0, + "maxBatchCount": 0, + }, + }, + }, + want: nil, + wantErr: true, + }, + { + name: "deregister", + args: args{ + opt: map[string]interface{}{ + "deregister": map[string]interface{}{ + "queueSize": 0, + "maxBatchCount": 0, + }, + }, + }, + want: nil, + wantErr: true, + }, + { + name: "clientRegister", + args: args{ + opt: map[string]interface{}{ + "clientRegister": map[string]interface{}{ + "queueSize": 0, + "maxBatchCount": 0, + }, + }, + }, + want: nil, + wantErr: true, + }, + { + name: "clientDeregister", + args: args{ + opt: map[string]interface{}{ + "clientDeregister": map[string]interface{}{ + "queueSize": 0, + "maxBatchCount": 0, + }, + }, + }, + want: nil, + wantErr: true, + }, + { + name: "heartbeat", + args: args{ + opt: map[string]interface{}{ + "heartbeat": map[string]interface{}{ + "queueSize": 0, + "maxBatchCount": 0, + }, + }, + }, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseBatchConfig(tt.args.opt) + if (err != nil) != tt.wantErr { + t.Errorf("ParseBatchConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseBatchConfig() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCheckCtrlConfig(t *testing.T) { + // 测试有效配置 + validCtrl := &CtrlConfig{ + QueueSize: 10, + MaxBatchCount: 100, + Concurrency: 5, + } + assert.True(t, checkCtrlConfig(validCtrl)) + + // 测试无效配置 + invalidCtrl := &CtrlConfig{ + QueueSize: 0, + MaxBatchCount: 0, + Concurrency: 0, + } + assert.False(t, checkCtrlConfig(invalidCtrl)) + + // 测试nil配置 + assert.True(t, checkCtrlConfig(nil)) +} diff --git a/service/circuitbreaker_rule.go b/service/circuitbreaker_rule.go index 139831e0a..fbc54b9f4 100644 --- a/service/circuitbreaker_rule.go +++ b/service/circuitbreaker_rule.go @@ -36,25 +36,9 @@ import ( "github.com/polarismesh/polaris/common/utils" ) -func checkBatchCircuitBreakerRules(req []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - // CreateCircuitBreakerRules Create a CircuitBreaker rule func (s *Server) CreateCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - if checkErr := checkBatchCircuitBreakerRules(request); checkErr != nil { - return checkErr - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, cbRule := range request { response := s.createCircuitBreakerRule(ctx, cbRule) @@ -182,10 +166,6 @@ var ( // DeleteCircuitBreakerRules Delete current CircuitBreaker rules func (s *Server) DeleteCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(request); err != nil { - return err - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { resp := s.deleteCircuitBreakerRule(ctx, entry) @@ -227,10 +207,6 @@ func (s *Server) deleteCircuitBreakerRule( // EnableCircuitBreakerRules Enable the CircuitBreaker rule func (s *Server) EnableCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(request); err != nil { - return err - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { resp := s.enableCircuitBreakerRule(ctx, entry) @@ -273,10 +249,6 @@ func (s *Server) enableCircuitBreakerRule( // UpdateCircuitBreakerRules Modify the CircuitBreaker rule func (s *Server) UpdateCircuitBreakerRules( ctx context.Context, request []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { - if err := checkBatchCircuitBreakerRules(request); err != nil { - return err - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, entry := range request { response := s.updateCircuitBreakerRule(ctx, entry) diff --git a/service/circuitbreaker_rule_test.go b/service/circuitbreaker_rule_test.go index b24dd58b3..07d9df31a 100644 --- a/service/circuitbreaker_rule_test.go +++ b/service/circuitbreaker_rule_test.go @@ -30,6 +30,8 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "github.com/stretchr/testify/assert" + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/service" ) @@ -171,6 +173,26 @@ func TestCreateCircuitBreakerRule(t *testing.T) { } defer discoverSuit.Destroy() + t.Run("abnormal_scene", func(t *testing.T) { + t.Run("empty_request", func(t *testing.T) { + resp := discoverSuit.DiscoverServer().CreateCircuitBreakerRules(discoverSuit.DefaultCtx, []*apifault.CircuitBreakerRule{}) + assert.False(t, api.IsSuccess(resp), resp.GetInfo().GetValue()) + assert.Equal(t, uint32(apimodel.Code_EmptyRequest), resp.GetCode().GetValue()) + }) + + t.Run("too_many_request", func(t *testing.T) { + requests := []*apifault.CircuitBreakerRule{} + for i := 0; i < utils.MaxBatchSize+10; i++ { + requests = append(requests, &apifault.CircuitBreakerRule{ + Id: "123123", + }) + } + resp := discoverSuit.DiscoverServer().CreateCircuitBreakerRules(discoverSuit.DefaultCtx, requests) + assert.False(t, api.IsSuccess(resp), resp.GetInfo().GetValue()) + assert.Equal(t, uint32(apimodel.Code_BatchSizeOverLimit), resp.GetCode().GetValue(), resp.GetInfo().GetValue()) + }) + }) + t.Run("正常创建熔断规则,返回成功", func(t *testing.T) { cbRules, resp := createCircuitBreakerRules(discoverSuit, testCount) defer cleanCircuitBreakerRules(discoverSuit, resp) diff --git a/service/client_info.go b/service/client_info.go index 4decff73e..bf2ea3854 100644 --- a/service/client_info.go +++ b/service/client_info.go @@ -48,7 +48,7 @@ func (s *Server) checkAndStoreClient(ctx context.Context, req *apiservice.Client if nil == client { needStore = true } else { - needStore = !clientEquals(client.Proto(), req) + needStore = !ClientEquals(client.Proto(), req) } if needStore { client, resp = s.createClient(ctx, req) @@ -155,7 +155,7 @@ func client2Api(client *model.Client) *apiservice.Client { return out } -func clientEquals(client1 *apiservice.Client, client2 *apiservice.Client) bool { +func ClientEquals(client1 *apiservice.Client, client2 *apiservice.Client) bool { if client1.GetId().GetValue() != client2.GetId().GetValue() { return false } diff --git a/service/client_info_test.go b/service/client_info_test.go new file mode 100644 index 000000000..cab1fc126 --- /dev/null +++ b/service/client_info_test.go @@ -0,0 +1,757 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package service_test + +import ( + "context" + "sync" + "testing" + + "github.com/golang/mock/gomock" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/types/known/wrapperspb" + + "github.com/polarismesh/polaris/cache" + api "github.com/polarismesh/polaris/common/api/v1" + apiv1 "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/service" + "github.com/polarismesh/polaris/service/batch" + "github.com/polarismesh/polaris/store" + "github.com/polarismesh/polaris/store/mock" +) + +func mockReportClients(cnt int) []*apiservice.Client { + ret := make([]*apiservice.Client, 0, 4) + + for i := 0; i < cnt; i++ { + ret = append(ret, &apiservice.Client{ + Host: utils.NewStringValue("127.0.0.1"), + Type: apiservice.Client_SDK, + Version: utils.NewStringValue("v1.0.0"), + Location: &apimodel.Location{}, + Id: utils.NewStringValue(utils.NewUUID()), + Stat: []*apiservice.StatInfo{ + { + Target: utils.NewStringValue(model.StatReportPrometheus), + Port: utils.NewUInt32Value(uint32(1000 + i)), + Path: utils.NewStringValue("/metrics"), + Protocol: utils.NewStringValue("http"), + }, + }, + }) + } + + return ret +} + +func TestServer_ReportClient(t *testing.T) { + t.Run("正常客户端上报", func(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + discoverSuit.cleanReportClient() + discoverSuit.Destroy() + }) + + clients := mockReportClients(1) + + for i := range clients { + resp := discoverSuit.DiscoverServer().ReportClient(discoverSuit.DefaultCtx, clients[i]) + assert.True(t, respSuccess(resp), resp.GetInfo().GetValue()) + } + }) + + t.Run("abnormal_scene", func(t *testing.T) { + + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + discoverSuit.cleanReportClient() + discoverSuit.Destroy() + }) + + client := mockReportClients(1)[0] + + svr := discoverSuit.OriginDiscoverServer().(*service.Server) + + t.Run("01_store_err", func(t *testing.T) { + ctrl := gomock.NewController(t) + oldStore := svr.Store() + oldBc := svr.GetBatchController() + + ctx, cancel := context.WithCancel(context.Background()) + defer func() { + cancel() + ctrl.Finish() + svr.MockBatchController(oldBc) + svr.TestSetStore(oldStore) + }() + + mockStore := mock.NewMockStore(ctrl) + mockBc, err := batch.NewBatchCtrlWithConfig(mockStore, discoverSuit.CacheMgr(), &batch.Config{ + ClientRegister: &batch.CtrlConfig{ + Open: true, + QueueSize: 1, + WaitTime: "32ms", + Concurrency: 1, + MaxBatchCount: 4, + }, + }) + assert.NoError(t, err) + mockBc.Start(ctx) + + svr.TestSetStore(mockStore) + svr.MockBatchController(mockBc) + + mockStore.EXPECT().BatchAddClients(gomock.Any()).Return(store.NewStatusError(store.Unknown, "mock error")).AnyTimes() + + rsp := discoverSuit.DiscoverServer().ReportClient(discoverSuit.DefaultCtx, client) + assert.False(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + }) + + t.Run("02_exist_resource", func(t *testing.T) { + ctrl := gomock.NewController(t) + oldStore := svr.Store() + defer func() { + ctrl.Finish() + svr.TestSetStore(oldStore) + }() + + mockStore := mock.NewMockStore(ctrl) + svr.TestSetStore(mockStore) + + mockStore.EXPECT().BatchAddClients(gomock.Any()).Return(store.NewStatusError(store.DuplicateEntryErr, "mock error")).AnyTimes() + + rsp := discoverSuit.DiscoverServer().ReportClient(discoverSuit.DefaultCtx, client) + assert.True(t, api.IsSuccess(rsp), rsp.GetInfo().GetValue()) + }) + }) +} + +func TestServer_GetReportClient(t *testing.T) { + t.Run("客户端上报-查询客户端信息", func(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + // 主动触发清理之前的 ReportClient 数据 + discoverSuit.cleanReportClient() + // 强制触发缓存更新 + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() + t.Log("finish sleep to wait cache refresh") + + t.Cleanup(func() { + discoverSuit.cleanReportClient() + discoverSuit.Destroy() + }) + + clients := mockReportClients(5) + + wait := sync.WaitGroup{} + wait.Add(5) + for i := range clients { + go func(client *apiservice.Client) { + defer wait.Done() + resp := discoverSuit.DiscoverServer().ReportClient(discoverSuit.DefaultCtx, client) + assert.True(t, respSuccess(resp), resp.GetInfo().GetValue()) + t.Logf("create one client success : %s", client.GetId().GetValue()) + }(clients[i]) + } + + wait.Wait() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() + t.Log("finish sleep to wait cache refresh") + + resp := discoverSuit.DiscoverServer().GetPrometheusTargets(context.Background(), map[string]string{}) + t.Logf("get report clients result: %#v", resp) + assert.Equal(t, apiv1.ExecuteSuccess, resp.Code) + }) +} + +func TestServer_GetReportClients(t *testing.T) { + discoverSuit := &DiscoverTestSuit{} + if err := discoverSuit.Initialize(); err != nil { + t.Fatal(err) + } + + t.Run("create client", func(t *testing.T) { + svr := discoverSuit.OriginDiscoverServer() + + mockClientId := utils.NewUUID() + resp := svr.ReportClient(context.Background(), &service_manage.Client{ + Host: utils.NewStringValue("127.0.0.1"), + Type: service_manage.Client_SDK, + Version: utils.NewStringValue("1.0.0"), + Location: &apimodel.Location{ + Region: utils.NewStringValue("region"), + Zone: utils.NewStringValue("zone"), + Campus: utils.NewStringValue("campus"), + }, + Id: utils.NewStringValue(mockClientId), + Stat: []*service_manage.StatInfo{ + { + Target: utils.NewStringValue("prometheus"), + Port: utils.NewUInt32Value(8080), + Path: utils.NewStringValue("/metrics"), + Protocol: utils.NewStringValue("http"), + }, + }, + }) + + assert.Equal(t, resp.GetCode().GetValue(), uint32(apimodel.Code_ExecuteSuccess)) + // 强制刷新到 cache + svr.Cache().(*cache.CacheManager).TestUpdate() + + originSvr := discoverSuit.OriginDiscoverServer().(*service.Server) + qresp := originSvr.GetReportClients(discoverSuit.DefaultCtx, map[string]string{}) + assert.Equal(t, resp.GetCode().GetValue(), uint32(apimodel.Code_ExecuteSuccess)) + assert.Equal(t, qresp.GetAmount().GetValue(), uint32(1)) + assert.Equal(t, qresp.GetSize().GetValue(), uint32(1)) + }) + + t.Run("invalid_search", func(t *testing.T) { + originSvr := discoverSuit.OriginDiscoverServer().(*service.Server) + resp := originSvr.GetReportClients(discoverSuit.DefaultCtx, map[string]string{ + "offset": "abc", + }) + + assert.False(t, api.IsSuccess(resp), resp.GetInfo().GetValue()) + assert.Equal(t, uint32(apimodel.Code_InvalidParameter), resp.GetCode().GetValue()) + + resp = originSvr.GetReportClients(discoverSuit.DefaultCtx, map[string]string{ + "version_123": "abc", + }) + + assert.False(t, api.IsSuccess(resp), resp.GetInfo().GetValue()) + assert.Equal(t, uint32(apimodel.Code_InvalidParameter), resp.GetCode().GetValue()) + + }) +} + +func Test_clientEquals(t *testing.T) { + type args struct { + client1 *apiservice.Client + client2 *apiservice.Client + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "full_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: true, + }, + { + name: "id_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("2"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "host_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("2.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "version_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.1.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "region_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.1.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region-1"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "zone_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.1.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone-1"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "campus_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.1.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus-1"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "stat_target_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.1.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus-1"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "stat_port_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.1.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28081), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "stat_path_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.1.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/v1/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + { + name: "stat_protocol_not_equal", + args: args{ + client1: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.1.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("tcp"), + }, + }, + }, + client2: &apiservice.Client{ + Id: wrapperspb.String("1"), + Host: wrapperspb.String("1.1.1.1"), + Version: wrapperspb.String("Java-1.0.0"), + Type: apiservice.Client_SDK, + Location: &apimodel.Location{ + Region: wrapperspb.String("region"), + Zone: wrapperspb.String("zone"), + Campus: wrapperspb.String("campus"), + }, + Stat: []*apiservice.StatInfo{ + { + Target: wrapperspb.String("prometheus"), + Port: wrapperspb.UInt32(28080), + Path: wrapperspb.String("/metrics"), + Protocol: wrapperspb.String("http"), + }, + }, + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := service.ClientEquals(tt.args.client1, tt.args.client2); got != tt.want { + t.Errorf("clientEquals() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/service/client_test.go b/service/client_test.go index d0f428910..625f44dad 100644 --- a/service/client_test.go +++ b/service/client_test.go @@ -18,22 +18,17 @@ package service_test import ( - "context" "fmt" - "sync" "testing" "github.com/golang/protobuf/proto" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" - "github.com/polarismesh/specification/source/go/api/v1/service_manage" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "github.com/stretchr/testify/assert" + "github.com/polarismesh/polaris/cache" api "github.com/polarismesh/polaris/common/api/v1" - apiv1 "github.com/polarismesh/polaris/common/api/v1" - "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/utils" - "github.com/polarismesh/polaris/service" ) // 测试discover instances @@ -62,7 +57,7 @@ func TestDiscoverInstances(t *testing.T) { reqInstances = append(reqInstances, req) } t.Run("正常服务发现,返回的数据齐全", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().ServiceInstancesCache(discoverSuit.DefaultCtx, &apiservice.DiscoverFilter{}, service) assert.True(t, respSuccess(out)) assert.Equal(t, count, len(out.GetInstances())) @@ -92,7 +87,7 @@ func TestDiscoverInstances(t *testing.T) { service.Metadata["new-metadata1"] = "1233" service.Metadata["new-metadata2"] = "2342" resp := discoverSuit.DiscoverServer().UpdateServices(discoverSuit.DefaultCtx, []*apiservice.Service{service}) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() assert.True(t, respSuccess(resp)) assert.NotEqual(t, resp.Responses[0].GetService().GetRevision().GetValue(), oldRevision) assert.Equal(t, resp.Responses[0].GetService().GetMetadata()["new-metadata1"], "1233") @@ -116,7 +111,7 @@ func TestDiscoverCircuitBreaker(t *testing.T) { defer cleanCircuitBreakerRules(discoverSuit, resp) service := &apiservice.Service{Name: utils.NewStringValue("testDestService"), Namespace: utils.NewStringValue("test")} t.Run("正常获取熔断规则", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetCircuitBreakerWithCache(discoverSuit.DefaultCtx, service) assert.True(t, respSuccess(out)) assert.Equal(t, len(out.GetCircuitBreaker().GetRules()), len(rules)) @@ -145,7 +140,7 @@ func TestDiscoverCircuitBreaker2(t *testing.T) { defer cleanCircuitBreakerRules(discoverSuit, resp) service := &apiservice.Service{Name: utils.NewStringValue("testDestService"), Namespace: utils.NewStringValue("default")} t.Run("熔断规则不存在", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetCircuitBreakerWithCache(discoverSuit.DefaultCtx, service) assert.True(t, respSuccess(out)) assert.Equal(t, 0, len(out.GetCircuitBreaker().GetRules())) @@ -192,7 +187,7 @@ func TestDiscoverService(t *testing.T) { expectService2.Metadata = meta _ = discoverSuit.DiscoverServer().UpdateServices(discoverSuit.DefaultCtx, []*apiservice.Service{expectService1}) _ = discoverSuit.DiscoverServer().UpdateServices(discoverSuit.DefaultCtx, []*apiservice.Service{expectService1}) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() t.Run("正常获取服务", func(t *testing.T) { requestService := &apiservice.Service{ @@ -267,7 +262,7 @@ func TestDiscoverRateLimits(t *testing.T) { defer discoverSuit.cleanRateLimit(rateLimitResp.GetId().GetValue()) defer discoverSuit.cleanRateLimitRevision(service.GetName().GetValue(), service.GetNamespace().GetValue()) t.Run("正常获取限流规则", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetRateLimitWithCache(discoverSuit.DefaultCtx, service) assert.True(t, respSuccess(out)) assert.Equal(t, len(out.GetRateLimit().GetRules()), 1) @@ -281,7 +276,7 @@ func TestDiscoverRateLimits(t *testing.T) { }) t.Run("限流规则已删除", func(t *testing.T) { discoverSuit.deleteRateLimit(t, rateLimitResp) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetRateLimitWithCache(discoverSuit.DefaultCtx, service) assert.True(t, respSuccess(out)) assert.Equal(t, len(out.GetRateLimit().GetRules()), 0) @@ -303,14 +298,14 @@ func TestDiscoverRateLimits2(t *testing.T) { _, service := discoverSuit.createCommonService(t, 1) defer discoverSuit.cleanServiceName(service.GetName().GetValue(), service.GetNamespace().GetValue()) t.Run("限流规则不存在", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetRateLimitWithCache(discoverSuit.DefaultCtx, service) assert.True(t, respSuccess(out)) assert.Nil(t, out.GetRateLimit()) t.Logf("pass: out is %+v", out) }) t.Run("服务不存在", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetRateLimitWithCache(discoverSuit.DefaultCtx, &apiservice.Service{ Name: utils.NewStringValue("not_exist_service"), Namespace: utils.NewStringValue("not_exist_namespace"), @@ -321,132 +316,6 @@ func TestDiscoverRateLimits2(t *testing.T) { }) } -func mockReportClients(cnt int) []*apiservice.Client { - ret := make([]*apiservice.Client, 0, 4) - - for i := 0; i < cnt; i++ { - ret = append(ret, &apiservice.Client{ - Host: utils.NewStringValue("127.0.0.1"), - Type: apiservice.Client_SDK, - Version: utils.NewStringValue("v1.0.0"), - Location: &apimodel.Location{}, - Id: utils.NewStringValue(utils.NewUUID()), - Stat: []*apiservice.StatInfo{ - { - Target: utils.NewStringValue(model.StatReportPrometheus), - Port: utils.NewUInt32Value(uint32(1000 + i)), - Path: utils.NewStringValue("/metrics"), - Protocol: utils.NewStringValue("http"), - }, - }, - }) - } - - return ret -} - -func TestServer_ReportClient(t *testing.T) { - t.Run("正常客户端上报", func(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - discoverSuit.cleanReportClient() - discoverSuit.Destroy() - }) - - clients := mockReportClients(1) - - for i := range clients { - resp := discoverSuit.DiscoverServer().ReportClient(discoverSuit.DefaultCtx, clients[i]) - assert.True(t, respSuccess(resp), resp.GetInfo().GetValue()) - } - }) -} - -func TestServer_GetReportClient(t *testing.T) { - t.Run("客户端上报-查询客户端信息", func(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - // 主动触发清理之前的 ReportClient 数据 - discoverSuit.cleanReportClient() - // 强制触发缓存更新 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() - t.Log("finish sleep to wait cache refresh") - - t.Cleanup(func() { - discoverSuit.cleanReportClient() - discoverSuit.Destroy() - }) - - clients := mockReportClients(5) - - wait := sync.WaitGroup{} - wait.Add(5) - for i := range clients { - go func(client *apiservice.Client) { - defer wait.Done() - resp := discoverSuit.DiscoverServer().ReportClient(discoverSuit.DefaultCtx, client) - assert.True(t, respSuccess(resp), resp.GetInfo().GetValue()) - t.Logf("create one client success : %s", client.GetId().GetValue()) - }(clients[i]) - } - - wait.Wait() - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() - t.Log("finish sleep to wait cache refresh") - - resp := discoverSuit.DiscoverServer().GetPrometheusTargets(context.Background(), map[string]string{}) - t.Logf("get report clients result: %#v", resp) - assert.Equal(t, apiv1.ExecuteSuccess, resp.Code) - }) -} - -func TestServer_GetReportClients(t *testing.T) { - discoverSuit := &DiscoverTestSuit{} - if err := discoverSuit.Initialize(); err != nil { - t.Fatal(err) - } - - t.Run("create client", func(t *testing.T) { - svr := discoverSuit.OriginDiscoverServer() - - mockClientId := utils.NewUUID() - resp := svr.ReportClient(context.Background(), &service_manage.Client{ - Host: utils.NewStringValue("127.0.0.1"), - Type: service_manage.Client_SDK, - Version: utils.NewStringValue("1.0.0"), - Location: &apimodel.Location{ - Region: utils.NewStringValue("region"), - Zone: utils.NewStringValue("zone"), - Campus: utils.NewStringValue("campus"), - }, - Id: utils.NewStringValue(mockClientId), - Stat: []*service_manage.StatInfo{ - { - Target: utils.NewStringValue("prometheus"), - Port: utils.NewUInt32Value(8080), - Path: utils.NewStringValue("/metrics"), - Protocol: utils.NewStringValue("http"), - }, - }, - }) - - assert.Equal(t, resp.GetCode().GetValue(), uint32(apimodel.Code_ExecuteSuccess)) - // 强制刷新到 cache - svr.Cache().TestUpdate() - - originSvr := discoverSuit.OriginDiscoverServer().(*service.Server) - qresp := originSvr.GetReportClients(discoverSuit.DefaultCtx, map[string]string{}) - assert.Equal(t, resp.GetCode().GetValue(), uint32(apimodel.Code_ExecuteSuccess)) - assert.Equal(t, qresp.GetAmount().GetValue(), uint32(1)) - assert.Equal(t, qresp.GetSize().GetValue(), uint32(1)) - }) -} - // TestServer_ReportServiceContract 测试上报服务合约 func TestServer_ReportServiceContract(t *testing.T) { discoverSuit := &DiscoverTestSuit{} diff --git a/service/client_v1.go b/service/client_v1.go index c40b845c1..0c08fc0b5 100644 --- a/service/client_v1.go +++ b/service/client_v1.go @@ -38,19 +38,16 @@ import ( // RegisterInstance create one instance func (s *Server) RegisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - ctx = context.WithValue(ctx, utils.ContextIsFromClient, true) return s.CreateInstance(ctx, req) } // DeregisterInstance delete one instance func (s *Server) DeregisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - ctx = context.WithValue(ctx, utils.ContextIsFromClient, true) return s.DeleteInstance(ctx, req) } // ReportServiceContract report client service interface info func (s *Server) ReportServiceContract(ctx context.Context, req *apiservice.ServiceContract) *apiservice.Response { - ctx = context.WithValue(ctx, utils.ContextIsFromClient, true) cacheData := s.caches.ServiceContract().Get(ctx, &model.ServiceContract{ Namespace: req.GetNamespace(), Service: req.GetService(), diff --git a/service/default.go b/service/default.go index 44493e657..5609dc9e9 100644 --- a/service/default.go +++ b/service/default.go @@ -48,7 +48,7 @@ const ( DefaultTLL = 5 ) -type ServerProxyFactory func(svr *Server, pre DiscoverServer) (DiscoverServer, error) +type ServerProxyFactory func(pre DiscoverServer) (DiscoverServer, error) var ( server DiscoverServer @@ -79,7 +79,7 @@ type Config struct { func Initialize(ctx context.Context, namingOpt *Config, opts ...InitOption) error { var err error once.Do(func() { - err = initialize(ctx, namingOpt, opts...) + namingServer, server, err = InitServer(ctx, namingOpt, opts...) }) if err != nil { @@ -109,36 +109,39 @@ func GetOriginServer() (*Server, error) { } // 内部初始化函数 -func initialize(ctx context.Context, namingOpt *Config, opts ...InitOption) error { +func InitServer(ctx context.Context, namingOpt *Config, opts ...InitOption) (*Server, DiscoverServer, error) { + actualSvr := new(Server) // l5service - namingServer.config = *namingOpt - namingServer.l5service = &l5service{} - namingServer.instanceChains = make([]InstanceChain, 0, 4) - namingServer.createServiceSingle = &singleflight.Group{} - namingServer.subCtxs = make([]*eventhub.SubscribtionContext, 0, 4) + actualSvr.config = *namingOpt + actualSvr.l5service = &l5service{} + actualSvr.instanceChains = make([]InstanceChain, 0, 4) + actualSvr.createServiceSingle = &singleflight.Group{} + actualSvr.subCtxs = make([]*eventhub.SubscribtionContext, 0, 4) for i := range opts { - opts[i](namingServer) + opts[i](actualSvr) } // 插件初始化 - pluginInitialize() + actualSvr.pluginInitialize() + var proxySvr DiscoverServer + proxySvr = actualSvr // 需要返回包装代理的 DiscoverServer order := namingOpt.Interceptors for i := range order { factory, exist := serverProxyFactories[order[i]] if !exist { - return fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) + return nil, nil, fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) } - proxySvr, err := factory(namingServer, server) + afterSvr, err := factory(proxySvr) if err != nil { - return err + return nil, nil, err } - server = proxySvr + proxySvr = afterSvr } - return nil + return actualSvr, proxySvr, nil } type PluginInstanceEventHandler struct { @@ -153,25 +156,19 @@ func (p *PluginInstanceEventHandler) OnEvent(ctx context.Context, any2 any) erro } // 插件初始化 -func pluginInitialize() { +func (svr *Server) pluginInitialize() { // 获取CMDB插件 - namingServer.cmdb = plugin.GetCMDB() - if namingServer.cmdb == nil { + svr.cmdb = plugin.GetCMDB() + if svr.cmdb == nil { log.Warnf("Not Found CMDB Plugin") } // 获取History插件,注意:插件的配置在bootstrap已经设置好 - namingServer.history = plugin.GetHistory() - if namingServer.history == nil { + svr.history = plugin.GetHistory() + if svr.history == nil { log.Warnf("Not Found History Log Plugin") } - // 获取限流插件 - namingServer.ratelimit = plugin.GetRatelimit() - if namingServer.ratelimit == nil { - log.Warnf("Not found Ratelimit Plugin") - } - subscriber := plugin.GetDiscoverEvent() if subscriber == nil { log.Warnf("Not found DiscoverEvent Plugin") @@ -179,18 +176,19 @@ func pluginInitialize() { } eventHandler := &PluginInstanceEventHandler{ - BaseInstanceEventHandler: NewBaseInstanceEventHandler(namingServer), + BaseInstanceEventHandler: NewBaseInstanceEventHandler(svr), subscriber: subscriber, } subCtx, err := eventhub.Subscribe(eventhub.InstanceEventTopic, eventHandler) if err != nil { log.Warnf("register DiscoverEvent into eventhub:%s %v", subscriber.Name(), err) } - namingServer.subCtxs = append(namingServer.subCtxs, subCtx) + svr.subCtxs = append(svr.subCtxs, subCtx) } func GetChainOrder() []string { return []string{ "auth", + "paramcheck", } } diff --git a/service/default_test.go b/service/default_test.go index 1c21dc3e4..24d2da235 100644 --- a/service/default_test.go +++ b/service/default_test.go @@ -21,12 +21,17 @@ import ( "context" "sync" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/polarismesh/polaris/auth" + "github.com/polarismesh/polaris/cache" cachemock "github.com/polarismesh/polaris/cache/mock" + "github.com/polarismesh/polaris/namespace" + "github.com/polarismesh/polaris/service/batch" + "github.com/polarismesh/polaris/service/healthcheck" "github.com/polarismesh/polaris/store/mock" ) @@ -40,6 +45,8 @@ func Test_Initialize(t *testing.T) { s := mock.NewMockStore(ctrl) cacheMgr := cachemock.NewMockCacheManager(ctrl) cacheMgr.EXPECT().OpenResourceCache(gomock.Any()).Return(nil).AnyTimes() + cacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + cacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() _, _, err := auth.TestInitialize(context.Background(), &auth.Config{ Option: map[string]interface{}{}, @@ -59,3 +66,43 @@ func Test_Initialize(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, dSvr) } + +func Test_Server(t *testing.T) { + t.Run("cache_entries", func(t *testing.T) { + ret := GetAllCaches() + assert.True(t, len(ret) > 0) + }) + + t.Run("with_test", func(t *testing.T) { + ctrl := gomock.NewController(t) + t.Cleanup(func() { + ctrl.Finish() + }) + + svr := &Server{} + + opt := []InitOption{} + + mockCacheMgr := cachemock.NewMockCacheManager(ctrl) + mockCacheMgr.EXPECT().OpenResourceCache(gomock.Any()).Return(nil).AnyTimes() + mockCacheMgr.EXPECT().GetReportInterval().Return(time.Second).AnyTimes() + mockCacheMgr.EXPECT().GetUpdateCacheInterval().Return(time.Second).AnyTimes() + + opt = append(opt, WithBatchController(&batch.Controller{})) + opt = append(opt, WithNamespaceSvr(&namespace.Server{})) + opt = append(opt, WithCacheManager(&cache.Config{}, mockCacheMgr)) + opt = append(opt, WithHealthCheckSvr(&healthcheck.Server{})) + opt = append(opt, WithStorage(mock.NewMockStore(ctrl))) + + for i := range opt { + opt[i](svr) + } + + assert.NotNil(t, svr.bc) + assert.NotNil(t, svr.namespaceSvr) + assert.NotNil(t, svr.caches) + assert.NotNil(t, svr.healthServer) + assert.NotNil(t, svr.storage) + + }) +} diff --git a/service/instance.go b/service/instance.go index 22415d238..13681b7f7 100644 --- a/service/instance.go +++ b/service/instance.go @@ -79,33 +79,16 @@ var ( // CreateInstances 批量创建服务实例 func (s *Server) CreateInstances(ctx context.Context, reqs []*apiservice.Instance) *apiservice.BatchWriteResponse { - if checkError := checkBatchInstance(reqs); checkError != nil { - return checkError - } - return batchOperateInstances(ctx, reqs, s.CreateInstance) } // CreateInstance create a single service instance func (s *Server) CreateInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) start := time.Now() - instanceID, checkError := checkCreateInstance(req) - if checkError != nil { - return checkError - } - // Restricted Instance frequently registered - if ok := s.allowInstanceAccess(instanceID); !ok { - log.Error("create instance not allowed to access: exceed ratelimit", - utils.ZapRequestID(rid), utils.ZapPlatformID(pid), utils.ZapInstanceID(instanceID)) - return api.NewInstanceResponse(apimodel.Code_InstanceTooManyRequests, req) - } // Prevent pollution api.Instance struct, copy and fill token ins := *req ins.ServiceToken = utils.NewStringValue(parseInstanceReqToken(ctx, req)) - ins.Id = utils.NewStringValue(instanceID) data, resp := s.createInstance(ctx, req, &ins) if resp != nil { return resp @@ -114,14 +97,14 @@ func (s *Server) CreateInstance(ctx context.Context, req *apiservice.Instance) * msg := fmt.Sprintf("create instance: id=%v, namespace=%v, service=%v, host=%v, port=%v", ins.GetId().GetValue(), req.GetNamespace().GetValue(), req.GetService().GetValue(), req.GetHost().GetValue(), req.GetPort().GetValue()) - log.Info(msg, utils.ZapRequestID(rid), utils.ZapPlatformID(pid), zap.Duration("cost", time.Since(start))) + log.Info(msg, utils.RequestID(ctx), zap.Duration("cost", time.Since(start))) svc := &model.Service{ Name: req.GetService().GetValue(), Namespace: req.GetNamespace().GetValue(), } instanceProto := data.Proto event := &model.InstanceEvent{ - Id: instanceID, + Id: req.GetId().GetValue(), Namespace: svc.Namespace, Service: svc.Name, Instance: instanceProto, @@ -191,13 +174,11 @@ func (s *Server) asyncCreateInstance( func (s *Server) serialCreateInstance( ctx context.Context, svcId string, req *apiservice.Instance, ins *apiservice.Instance) ( *model.Instance, *apiservice.Response) { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) instance, err := s.storage.GetInstance(ins.GetId().GetValue()) if err != nil { log.Error("[Instance] get instance from store", - utils.ZapRequestID(rid), utils.ZapPlatformID(pid), zap.Error(err)) + utils.RequestID(ctx), zap.Error(err)) return nil, api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } // 如果存在,则替换实例的属性数据,但是需要保留用户设置的隔离状态,以免出现关键状态丢失 @@ -207,7 +188,7 @@ func (s *Server) serialCreateInstance( // 直接同步创建服务实例 data := model.CreateInstanceModel(svcId, ins) if err := s.storage.AddInstance(data); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, wrapperInstanceStoreResponse(req, err) } @@ -216,31 +197,12 @@ func (s *Server) serialCreateInstance( // DeleteInstances 批量删除服务实例 func (s *Server) DeleteInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse { - if checkError := checkBatchInstance(req); checkError != nil { - return checkError - } - return batchOperateInstances(ctx, req, s.DeleteInstance) } // DeleteInstance 删除单个服务实例 func (s *Server) DeleteInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) - - // 参数检查 - instanceID, checkError := checkReviseInstance(req) - if checkError != nil { - return checkError - } - // 限制instance频繁反注册 - if ok := s.allowInstanceAccess(instanceID); !ok { - log.Error("delete instance is not allow access", utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) - return api.NewInstanceResponse(apimodel.Code_InstanceTooManyRequests, req) - } - ins := *req // 防止污染外部的req - ins.Id = utils.NewStringValue(instanceID) ins.ServiceToken = utils.NewStringValue(parseInstanceReqToken(ctx, req)) return s.deleteInstance(ctx, req, &ins) } @@ -263,11 +225,9 @@ func (s *Server) serialDeleteInstance( ctx context.Context, req *apiservice.Instance, ins *apiservice.Instance) *apiservice.Response { start := time.Now() // 检查服务实例是否存在 - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) instance, err := s.storage.GetInstance(ins.GetId().GetValue()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } if instance == nil { @@ -282,13 +242,13 @@ func (s *Server) serialDeleteInstance( // 存储层操作 if err := s.storage.DeleteInstance(instance.ID()); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperInstanceStoreResponse(req, err) } msg := fmt.Sprintf("delete instance: id=%v, namespace=%v, service=%v, host=%v, port=%v", instance.ID(), service.Namespace, service.Name, instance.Host(), instance.Port()) - log.Info(msg, utils.ZapRequestID(rid), utils.ZapPlatformID(pid), zap.Duration("cost", time.Since(start))) + log.Info(msg, utils.RequestID(ctx), zap.Duration("cost", time.Since(start))) s.RecordHistory(ctx, instanceRecordEntry(ctx, req, service, instance, model.ODelete)) event := &model.InstanceEvent{ Id: instance.ID(), @@ -309,8 +269,6 @@ func (s *Server) serialDeleteInstance( func (s *Server) asyncDeleteInstance( ctx context.Context, req *apiservice.Instance, ins *apiservice.Instance) *apiservice.Response { start := time.Now() - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) allowAsyncRegis, _ := ctx.Value(utils.ContextOpenAsyncRegis).(bool) future := s.bc.AsyncDeleteInstance(ins, !allowAsyncRegis) if err := future.Wait(); err != nil { @@ -318,7 +276,7 @@ func (s *Server) asyncDeleteInstance( if future.Code() == apimodel.Code_NotFoundResource { return api.NewInstanceResponse(apimodel.Code_ExecuteSuccess, req) } - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewInstanceResponse(future.Code(), req) } instance := future.Instance() @@ -326,7 +284,7 @@ func (s *Server) asyncDeleteInstance( // 打印本地日志与操作记录 msg := fmt.Sprintf("delete instance: id=%v, namespace=%v, service=%v, host=%v, port=%v", instance.ID(), instance.Namespace(), instance.Service(), instance.Host(), instance.Port()) - log.Info(msg, utils.ZapRequestID(rid), utils.ZapPlatformID(pid), zap.Duration("cost", time.Since(start))) + log.Info(msg, utils.RequestID(ctx), zap.Duration("cost", time.Since(start))) service := &model.Service{Name: instance.Service(), Namespace: instance.Namespace()} s.RecordHistory(ctx, instanceRecordEntry(ctx, req, service, instance, model.ODelete)) event := &model.InstanceEvent{ @@ -346,23 +304,11 @@ func (s *Server) asyncDeleteInstance( // DeleteInstancesByHost 根据host批量删除服务实例 func (s *Server) DeleteInstancesByHost( ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse { - if checkError := checkBatchInstance(req); checkError != nil { - return checkError - } - return batchOperateInstances(ctx, req, s.DeleteInstanceByHost) } // DeleteInstanceByHost 根据host删除服务实例 func (s *Server) DeleteInstanceByHost(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - - // 参数校验 - if err := checkInstanceByHost(req); err != nil { - return err - } - // 获取实例 instances, service, err := s.getInstancesMainByService(ctx, req) if err != nil { @@ -379,14 +325,14 @@ func (s *Server) DeleteInstanceByHost(ctx context.Context, req *apiservice.Insta } if err := s.storage.BatchDeleteInstances(ids); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperInstanceStoreResponse(req, err) } for _, instance := range instances { msg := fmt.Sprintf("delete instance: id=%v, namespace=%v, service=%v, host=%v, port=%v", instance.ID(), service.Namespace, service.Name, instance.Host(), instance.Port()) - log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, instanceRecordEntry(ctx, req, service, instance, model.ODelete)) s.sendDiscoverEvent(model.InstanceEvent{ Id: instance.ID(), @@ -402,10 +348,6 @@ func (s *Server) DeleteInstanceByHost(ctx context.Context, req *apiservice.Insta // UpdateInstances 批量修改服务实例 func (s *Server) UpdateInstances(ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse { - if checkError := checkBatchInstance(req); checkError != nil { - return checkError - } - return batchOperateInstances(ctx, req, s.UpdateInstance) } @@ -415,32 +357,26 @@ func (s *Server) UpdateInstance(ctx context.Context, req *apiservice.Instance) * if preErr != nil { return preErr } - if err := checkMetadata(req.GetMetadata()); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidMetadata, req) - } - // 修改 - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - log.Info(fmt.Sprintf("old instance: %+v", instance), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(fmt.Sprintf("old instance: %+v", instance), utils.RequestID(ctx)) var eventTypes map[model.InstanceEventType]bool var needUpdate bool // 存储层操作 if needUpdate, eventTypes = s.updateInstanceAttribute(req, instance); !needUpdate { log.Info("update instance no data change, no need update", - utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID), zap.String("instance", req.String())) + utils.RequestID(ctx), zap.String("instance", req.String())) return api.NewInstanceResponse(apimodel.Code_NoNeedUpdate, req) } if err := s.storage.UpdateInstance(instance); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperInstanceStoreResponse(req, err) } msg := fmt.Sprintf("update instance: id=%v, namespace=%v, service=%v, host=%v, port=%v, healthy = %v", instance.ID(), service.Namespace, service.Name, instance.Host(), instance.Port(), instance.Healthy()) - log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, instanceRecordEntry(ctx, req, service, instance, model.OUpdate)) for eventType := range eventTypes { @@ -467,23 +403,12 @@ func (s *Server) UpdateInstance(ctx context.Context, req *apiservice.Instance) * // @note 必填参数为service+namespace+host func (s *Server) UpdateInstancesIsolate( ctx context.Context, req []*apiservice.Instance) *apiservice.BatchWriteResponse { - if checkError := checkBatchInstance(req); checkError != nil { - return checkError - } - return batchOperateInstances(ctx, req, s.UpdateInstanceIsolate) } // UpdateInstanceIsolate 修改服务实例隔离状态 // @note 必填参数为service+namespace+ip func (s *Server) UpdateInstanceIsolate(ctx context.Context, req *apiservice.Instance) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - - // 参数校验 - if err := checkInstanceByHost(req); err != nil { - return err - } if req.GetIsolate() == nil { return api.NewInstanceResponse(apimodel.Code_InvalidInstanceIsolate, req) } @@ -522,14 +447,14 @@ func (s *Server) UpdateInstanceIsolate(ctx context.Context, req *apiservice.Inst } if err := s.storage.BatchSetInstanceIsolate(ids, isolate, utils.NewUUID()); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperInstanceStoreResponse(req, err) } for _, instance := range instances { msg := fmt.Sprintf("update instance: id=%v, namespace=%v, service=%v, host=%v, port=%v, isolate=%v", instance.ID(), service.Namespace, service.Name, instance.Host(), instance.Port(), instance.Isolate()) - log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, instanceRecordEntry(ctx, req, service, instance, model.OUpdateIsolate)) // 比对下更新前后的 isolate 状态 @@ -556,38 +481,16 @@ func (s *Server) UpdateInstanceIsolate(ctx context.Context, req *apiservice.Inst return api.NewInstanceResponse(apimodel.Code_ExecuteSuccess, req) } -/** - * @brief 根据ip隔离和删除服务实例的参数检查 - */ -func checkInstanceByHost(req *apiservice.Instance) *apiservice.Response { - if req == nil { - return api.NewInstanceResponse(apimodel.Code_EmptyRequest, req) - } - if err := utils.CheckResourceName(req.GetService()); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidServiceName, req) - } - if err := utils.CheckResourceName(req.GetNamespace()); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidNamespaceName, req) - } - if err := checkInstanceHost(req.GetHost()); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidInstanceHost, req) - } - return nil -} - /** * @brief 根据服务和host获取服务实例 */ func (s *Server) getInstancesMainByService(ctx context.Context, req *apiservice.Instance) ( []*model.Instance, *model.Service, *apiservice.Response) { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - // 检查服务 // 这里获取的是源服务的token。如果是别名,service=nil service, err := s.storage.GetSourceServiceToken(req.GetService().GetValue(), req.GetNamespace().GetValue()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, nil, api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } if service == nil { @@ -597,7 +500,7 @@ func (s *Server) getInstancesMainByService(ctx context.Context, req *apiservice. // 获取服务实例 instances, err := s.storage.GetInstancesMainByService(service.ID, req.GetHost().GetValue()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, nil, api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } return instances, service, nil @@ -766,10 +669,7 @@ func (s *Server) GetInstances(ctx context.Context, query map[string]string) *api return batchErr } // 分页数据 - offset, limit, err := utils.ParseOffsetAndLimit(filters) - if err != nil { - return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) - } + offset, limit, _ := utils.ParseOffsetAndLimit(filters) total, instances, err := s.Cache().Instance().QueryInstances(filters, metaFilter, offset, limit) if err != nil { @@ -939,18 +839,10 @@ func (s *Server) GetInstancesCount(ctx context.Context) *apiservice.BatchQueryRe // update/delete instance前置条件 func (s *Server) execInstancePreStep(ctx context.Context, req *apiservice.Instance) ( *model.Service, *model.Instance, *apiservice.Response) { - rid := utils.ParseRequestID(ctx) - - // 参数检查 - instanceID, checkError := checkReviseInstance(req) - if checkError != nil { - return nil, nil, checkError - } - // 检查服务实例是否存在 - instance, err := s.storage.GetInstance(instanceID) + instance, err := s.storage.GetInstance(req.GetId().GetValue()) if err != nil { - log.Error("[Instance] get instance from store", utils.ZapRequestID(rid), utils.ZapInstanceID(instanceID), + log.Error("[Instance] get instance from store", utils.RequestID(ctx), utils.ZapInstanceID(req.GetId().GetValue()), zap.Error(err)) return nil, nil, api.NewInstanceResponse(commonstore.StoreCode2APICode(err), req) } @@ -1148,84 +1040,6 @@ func (s *Server) loadServiceByID(svcID string) (*model.Service, error) { return svc, nil } -/* - * @brief 检查批量请求 - */ -func checkBatchInstance(req []*apiservice.Instance) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - -/* - * @brief 检查创建服务实例请求参数 - */ -func checkCreateInstance(req *apiservice.Instance) (string, *apiservice.Response) { - if req == nil { - return "", api.NewInstanceResponse(apimodel.Code_EmptyRequest, req) - } - - if err := checkMetadata(req.GetMetadata()); err != nil { - return "", api.NewInstanceResponse(apimodel.Code_InvalidMetadata, req) - } - - // 检查字段长度是否大于DB中对应字段长 - err, notOk := CheckDbInstanceFieldLen(req) - if notOk { - return "", err - } - - return utils.CheckInstanceTetrad(req) -} - -/* - * @brief 检查删除/修改服务实例请求参数 - */ -func checkReviseInstance(req *apiservice.Instance) (string, *apiservice.Response) { - if req == nil { - return "", api.NewInstanceResponse(apimodel.Code_EmptyRequest, req) - } - - if req.GetId() != nil { - if req.GetId().GetValue() == "" { - return "", api.NewInstanceResponse(apimodel.Code_InvalidInstanceID, req) - } - return req.GetId().GetValue(), nil - } - - // 检查字段长度是否大于DB中对应字段长 - err, notOk := CheckDbInstanceFieldLen(req) - if notOk { - return "", err - } - - return utils.CheckInstanceTetrad(req) -} - -/* - * @brief 检查心跳实例请求参数 - * 检查是否存在token,以及 id或者四元组 - * 注意:心跳上报只允许从client上报,因此token只会存在req中 - */ -func checkHeartbeatInstance(req *apiservice.Instance) (string, *apiservice.Response) { - if req == nil { - return "", api.NewInstanceResponse(apimodel.Code_EmptyRequest, req) - } - if req.GetId() != nil { - if req.GetId().GetValue() == "" { - return "", api.NewInstanceResponse(apimodel.Code_InvalidInstanceID, req) - } - return req.GetId().GetValue(), nil - } - return utils.CheckInstanceTetrad(req) -} - // 获取instance请求的token信息 func parseInstanceReqToken(ctx context.Context, req *apiservice.Instance) string { if reqToken := req.GetServiceToken().GetValue(); reqToken != "" { @@ -1237,29 +1051,14 @@ func parseInstanceReqToken(ctx context.Context, req *apiservice.Instance) string // 实例查询前置处理 func preGetInstances(query map[string]string) (map[string]string, map[string]string, *apiservice.BatchQueryResponse) { - // 不允许全量查询服务实例 - if len(query) == 0 { - return nil, nil, api.NewBatchQueryResponse(apimodel.Code_EmptyQueryParameter) - } - var metaFilter map[string]string metaKey, metaKeyAvail := query["keys"] - metaValue, metaValueAvail := query["values"] - if metaKeyAvail != metaValueAvail { - return nil, nil, api.NewBatchQueryResponseWithMsg( - apimodel.Code_InvalidQueryInsParameter, "instance metadata key and value must be both provided") - } if metaKeyAvail { metaFilter = map[string]string{} keys := strings.Split(metaKey, ",") - values := strings.Split(metaValue, ",") - if len(keys) == len(values) { - for i := range keys { - metaFilter[keys[i]] = values[i] - } - } else { - return nil, nil, api.NewBatchQueryResponseWithMsg( - apimodel.Code_InvalidQueryInsParameter, "instance metadata key and value length are different") + values := strings.Split(query["values"], ",") + for i := range keys { + metaFilter[keys[i]] = values[i] } } @@ -1272,17 +1071,6 @@ func preGetInstances(query map[string]string) (map[string]string, map[string]str filters := make(map[string]string) for key, value := range query { - if _, ok := InstanceFilterAttributes[key]; !ok { - log.Errorf("[Server][Instance][Query] attribute(%s) is not allowed", key) - return nil, metaFilter, api.NewBatchQueryResponseWithMsg( - apimodel.Code_InvalidParameter, key+" is not allowed") - } - - if value == "" { - log.Errorf("[Server][Instance][Query] attribute(%s: %s) is not allowed empty", key, value) - return nil, metaFilter, api.NewBatchQueryResponseWithMsg( - apimodel.Code_InvalidParameter, "the value for "+key+" is empty") - } if attr, ok := InsFilter2toreAttr[key]; ok { key = attr } @@ -1345,39 +1133,6 @@ func instanceRecordEntry(ctx context.Context, req *apiservice.Instance, service return entry } -// CheckDbInstanceFieldLen 检查DB中service表对应的入参字段合法性 -func CheckDbInstanceFieldLen(req *apiservice.Instance) (*apiservice.Response, bool) { - if err := utils.CheckDbStrFieldLen(req.GetService(), MaxDbServiceNameLength); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidServiceName, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidNamespaceName, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetHost(), MaxDbInsHostLength); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidInstanceHost, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetProtocol(), MaxDbInsProtocolLength); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidInstanceProtocol, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetVersion(), MaxDbInsVersionLength); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidInstanceVersion, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetLogicSet(), MaxDbInsLogicSetLength); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidInstanceLogicSet, req), true - } - if err := utils.CheckDbMetaDataFieldLen(req.GetMetadata()); err != nil { - return api.NewInstanceResponse(apimodel.Code_InvalidMetadata, req), true - } - if req.GetPort().GetValue() > 65535 { - return api.NewInstanceResponse(apimodel.Code_InvalidInstancePort, req), true - } - - if req.GetWeight().GetValue() > 65535 { - return api.NewInstanceResponse(apimodel.Code_InvalidParameter, req), true - } - return nil, false -} - type InstanceChain interface { // AfterUpdate . AfterUpdate(ctx context.Context, instances ...*model.Instance) diff --git a/service/instance_check_test.go b/service/instance_check_test.go index c2e321f1c..a6cf62b7b 100644 --- a/service/instance_check_test.go +++ b/service/instance_check_test.go @@ -23,6 +23,7 @@ import ( "testing" "time" + "github.com/polarismesh/polaris/cache" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "github.com/stretchr/testify/assert" @@ -67,7 +68,7 @@ func TestInstanceCheck(t *testing.T) { time.Sleep(1 * time.Second) } - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() instance1 := discoverSuit.DiscoverServer().Cache().Instance().GetInstance(instanceId1) assert.NotNil(t, instance1) assert.Equal(t, true, instance1.Proto.GetHealthy().GetValue()) diff --git a/service/instance_test.go b/service/instance_test.go index f1ca032a2..a5d25008b 100644 --- a/service/instance_test.go +++ b/service/instance_test.go @@ -35,6 +35,7 @@ import ( "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/wrapperspb" + "github.com/polarismesh/polaris/cache" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" @@ -113,7 +114,7 @@ func TestCreateInstance(t *testing.T) { t.Fatalf("error: %+v", resp) } // 强制先update一次,规避上一次的数据查询结果 - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{}) }) @@ -159,7 +160,7 @@ func TestCreateInstance(t *testing.T) { t.Fatalf("error: %+v", resp) } // 强制先update一次,规避上一次的数据查询结果 - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() getResp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{"host": instanceReq.GetHost().GetValue()}) assert.True(t, getResp.GetCode().GetValue() == api.ExecuteSuccess) t.Logf("%+v", getResp) @@ -235,7 +236,7 @@ func TestCreateInstanceWithNoService(t *testing.T) { }() // 等待一段时间的刷新 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resps := discoverSuit.DiscoverServer().CreateInstances(discoverSuit.DefaultCtx, reqs) if respSuccess(resps) { @@ -264,7 +265,7 @@ func TestCreateInstance2(t *testing.T) { serviceResps = append(serviceResps, serviceResp) } - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() total := 20 var wg sync.WaitGroup start := time.Now() @@ -332,7 +333,7 @@ func TestUpdateInstanceManyTimes(t *testing.T) { ret := &apiservice.Instance{} proto.Unmarshal(marshalVal, ret) - ret.Weight.Value = uint32(rand.Int() % 32767) + ret.Weight = wrapperspb.UInt32(uint32(rand.Int() % 32767)) if updateResp := discoverSuit.DiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*apiservice.Instance{instanceReq}); !respSuccess(updateResp) { errs <- fmt.Errorf("error: %+v", updateResp) return @@ -388,7 +389,7 @@ func TestGetInstancesById(t *testing.T) { t.Run("根据精准匹配ID进行获取实例", func(t *testing.T) { instId := fmt.Sprintf("%s%d", idPrefix, 0) // 强制先update一次,规避上一次的数据查询结果 - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{"id": instId}) assert.True(t, respSuccess(out)) assert.Equal(t, 1, len(out.GetInstances())) @@ -402,7 +403,7 @@ func TestGetInstancesById(t *testing.T) { t.Run("根据前缀匹配ID进行获取实例", func(t *testing.T) { instId := fmt.Sprintf("%s%s", idPrefix, "*") // 强制先update一次,规避上一次的数据查询结果 - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{"id": instId}) assert.True(t, respSuccess(out)) assert.Equal(t, prefixCount, len(out.GetInstances())) @@ -415,7 +416,7 @@ func TestGetInstancesById(t *testing.T) { t.Run("根据后缀匹配ID进行获取实例", func(t *testing.T) { instId := fmt.Sprintf("%s%s", "*", idSuffix) // 强制先update一次,规避上一次的数据查询结果 - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() out := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{"id": instId}) assert.True(t, respSuccess(out)) assert.Equal(t, suffixCount, len(out.GetInstances())) @@ -435,16 +436,16 @@ func TestGetInstances(t *testing.T) { } defer discoverSuit.Destroy() t.Run("可以正常获取到实例信息", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() // 为了防止影响,每个函数需要把缓存的内容清空 + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() // 为了防止影响,每个函数需要把缓存的内容清空 _, serviceResp := discoverSuit.createCommonService(t, 320) defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() instanceReq, instanceResp := discoverSuit.createCommonInstance(t, serviceResp, 30) defer discoverSuit.cleanInstance(instanceResp.GetId().GetValue()) // 需要等待一会,等本地缓存更新 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() req := &apiservice.Service{ Name: utils.NewStringValue(instanceResp.GetService().GetValue()), Namespace: utils.NewStringValue(instanceResp.GetNamespace().GetValue()), @@ -464,7 +465,7 @@ func TestGetInstances(t *testing.T) { t.Logf("pass: %+v", resp.GetInstances()[0]) }) t.Run("注册实例,查询实例列表,实例反注册,revision会改变", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() // 为了防止影响,每个函数需要把缓存的内容清空 + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() // 为了防止影响,每个函数需要把缓存的内容清空 _, serviceResp := discoverSuit.createCommonService(t, 100) defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) @@ -472,7 +473,7 @@ func TestGetInstances(t *testing.T) { defer discoverSuit.cleanInstance(instanceResp.GetId().GetValue()) // 需要等待一会,等本地缓存更新 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().ServiceInstancesCache(discoverSuit.DefaultCtx, &apiservice.DiscoverFilter{}, serviceResp) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -483,7 +484,7 @@ func TestGetInstances(t *testing.T) { _, instanceResp = discoverSuit.createCommonInstance(t, serviceResp, 100) defer discoverSuit.cleanInstance(instanceResp.GetId().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp = discoverSuit.DiscoverServer().ServiceInstancesCache(discoverSuit.DefaultCtx, &apiservice.DiscoverFilter{}, serviceResp) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -506,7 +507,7 @@ func TestGetInstances1(t *testing.T) { defer discoverSuit.Destroy() discover := func(t *testing.T, service *apiservice.Service, check func(cnt int) bool) *apiservice.DiscoverResponse { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() time.Sleep(discoverSuit.UpdateCacheInterval()) resp := discoverSuit.DiscoverServer().ServiceInstancesCache(discoverSuit.DefaultCtx, &apiservice.DiscoverFilter{}, service) if !respSuccess(resp) { @@ -519,7 +520,7 @@ func TestGetInstances1(t *testing.T) { return resp } t.Run("注册并反注册多个实例,可以正常获取", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() // 为了防止影响,每个函数需要把缓存的内容清空 + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() // 为了防止影响,每个函数需要把缓存的内容清空 _, serviceResp := discoverSuit.createCommonService(t, 320) defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) @@ -543,7 +544,7 @@ func TestGetInstances1(t *testing.T) { }) }) t.Run("传递revision, revision有变化则有数据,否则无数据返回", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() // 为了防止影响,每个函数需要把缓存的内容清空 + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() // 为了防止影响,每个函数需要把缓存的内容清空 _, serviceResp := discoverSuit.createCommonService(t, 100) defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) for i := 0; i < 5; i++ { @@ -620,7 +621,7 @@ func TestRemoveInstance(t *testing.T) { wCtx, _ := eventhub.SubscribeWithFunc(eventhub.CacheInstanceEventTopic, func(ctx context.Context, any2 any) error { time.Sleep(3 * time.Second) event := any2.(*eventhub.CacheInstanceEvent) - t.Logf("receive instance change event : %#v", event) + t.Logf("receive instance change event : %s", utils.MustJson(event)) switch event.EventType { case eventhub.EventCreated: waitCreateOnce.Do(func() { @@ -640,7 +641,7 @@ func TestRemoveInstance(t *testing.T) { _, instanceResp := discoverSuit.createCommonInstance(t, serviceResp, 1111) defer discoverSuit.cleanInstance(instanceResp.GetId().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() waitCreate.Wait() discoverSuit.HeartBeat(t, serviceResp, instanceResp.GetId().GetValue()) resp := discoverSuit.GetLastHeartBeat(t, serviceResp, instanceResp.GetId().GetValue()) @@ -648,8 +649,9 @@ func TestRemoveInstance(t *testing.T) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) } + t.Logf("begin deregister instance : %s", instanceResp.GetId().GetValue()) discoverSuit.removeCommonInstance(t, serviceResp, instanceResp.GetId().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() waitRemove.Wait() resp = discoverSuit.GetLastHeartBeat(t, serviceResp, instanceResp.GetId().GetValue()) if !respNotFound(resp) { @@ -680,7 +682,7 @@ func TestListInstances(t *testing.T) { query["port"] = strconv.FormatUint(uint64(instanceReq.GetPort().GetValue()), 10) // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -695,7 +697,7 @@ func TestListInstances(t *testing.T) { _, serviceResp := discoverSuit.createCommonService(t, 115) defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() total := 50 var responses []*apiservice.Instance for i := 0; i < total; i++ { @@ -712,7 +714,7 @@ func TestListInstances(t *testing.T) { query := map[string]string{"offset": "10", "limit": "20", "host": "127.0.0.1"} // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -723,7 +725,7 @@ func TestListInstances(t *testing.T) { query = map[string]string{"offset": "10", "limit": "20"} // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp = discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -751,7 +753,7 @@ func TestListInstances(t *testing.T) { query := map[string]string{"offset": "0", "limit": "200"} // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -764,7 +766,7 @@ func TestListInstances(t *testing.T) { _, serviceResp := discoverSuit.createCommonService(t, 200) defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() total := 10 instance := new(apiservice.Instance) for i := 0; i < total; i++ { @@ -778,7 +780,7 @@ func TestListInstances(t *testing.T) { query := map[string]string{"limit": "20", "host": host, "port": port} // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -826,7 +828,7 @@ func TestListInstances1(t *testing.T) { } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) checkAmountAndSize(t, resp, total, 100) }) @@ -856,7 +858,7 @@ func TestListInstances1(t *testing.T) { } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) checkAmountAndSize(t, resp, total/2, total/2) @@ -881,7 +883,7 @@ func TestListInstances1(t *testing.T) { "healthy": "false", } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() checkAmountAndSize(t, discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query), 1, 1) @@ -928,7 +930,7 @@ func TestListInstances1(t *testing.T) { "values": "internal-personal-xxx_10", } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() checkAmountAndSize(t, discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query), 1, 1) @@ -982,7 +984,7 @@ func TestListInstances1(t *testing.T) { "keys": "internal-personal-xxx", } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) if resp.GetCode().GetValue() != api.InvalidQueryInsParameter { @@ -1035,7 +1037,7 @@ func TestListInstances2(t *testing.T) { } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) checkAmountAndSize(t, resp, 1, 1) @@ -1062,7 +1064,7 @@ func TestListInstances2(t *testing.T) { } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) checkAmountAndSize(t, resp, 1, 1) @@ -1097,7 +1099,7 @@ func TestListInstances2(t *testing.T) { "values": "1111", } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) checkAmountAndSize(t, resp, 1, 1) @@ -1156,7 +1158,7 @@ func TestInstancesContainLocation(t *testing.T) { defer discoverSuit.cleanInstance(resp.Responses[0].GetInstance().GetId().GetValue()) // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() getResp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{ "service": instance.GetService().GetValue(), "namespace": instance.GetNamespace().GetValue(), }) @@ -1170,7 +1172,7 @@ func TestInstancesContainLocation(t *testing.T) { t.Logf("%v", getInstances[0]) locationCheck(instance.GetLocation(), getInstances[0].GetLocation()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() discoverResp := discoverSuit.DiscoverServer().ServiceInstancesCache(discoverSuit.DefaultCtx, &apiservice.DiscoverFilter{}, service) if len(discoverResp.GetInstances()) != 1 { t.Fatalf("error: %d", len(discoverResp.GetInstances())) @@ -1225,7 +1227,7 @@ func TestUpdateInstance(t *testing.T) { "port": strconv.FormatUint(uint64(instanceReq.GetPort().GetValue()), 10), } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -1352,7 +1354,7 @@ func TestUpdateIsolate(t *testing.T) { } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() // 检查隔离状态和revision是否改变 for i := 0; i < instanceNum/portNum; i++ { filter := map[string]string{ @@ -1518,7 +1520,7 @@ func TestUpdateHealthCheck(t *testing.T) { "port": strconv.FormatUint(uint64(req.GetPort().GetValue()), 10), } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, query) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -1606,7 +1608,7 @@ func TestDeleteInstance(t *testing.T) { getInstance := func(t *testing.T, s *apiservice.Service, expect int) []*apiservice.Instance { filters := map[string]string{"service": s.GetName().GetValue(), "namespace": s.GetNamespace().GetValue()} // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() getResp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, filters) if !respSuccess(getResp) { t.Fatalf("error") @@ -1817,7 +1819,7 @@ func TestBatchDeleteInstances(t *testing.T) { t.Logf("%+v", out) } // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resps := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{ "service": service.GetName().GetValue(), "namespace": service.GetNamespace().GetValue(), @@ -1874,16 +1876,16 @@ func TestInstanceResponse(t *testing.T) { t.Run("删除实例,返回的信息包括req,不增加信息", func(t *testing.T) { req, resp := create() defer discoverSuit.cleanInstance(resp.GetId().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resps := discoverSuit.DiscoverServer().DeleteInstances(discoverSuit.DefaultCtx, []*apiservice.Instance{req}) if !respSuccess(resps) { t.Fatalf("error: %+v", resps) } respIns := resps.GetResponses()[0].GetInstance() - if respIns.GetId().GetValue() != "" || respIns.GetService() != req.GetService() || - respIns.GetNamespace() != req.GetNamespace() || respIns.GetHost() != req.GetHost() || - respIns.GetPort() != req.GetPort() || respIns.GetServiceToken() != req.GetServiceToken() { - t.Fatalf("error") + if respIns.GetService().GetValue() != req.GetService().GetValue() || + respIns.GetNamespace().GetValue() != req.GetNamespace().GetValue() || respIns.GetHost().GetValue() != req.GetHost().GetValue() || + respIns.GetPort().GetValue() != req.GetPort().GetValue() || respIns.GetServiceToken().GetValue() != req.GetServiceToken().GetValue() { + t.Fatalf("error; \n%s, \n%s", utils.MustJson(req), utils.MustJson(respIns)) } t.Logf("pass") }) @@ -1897,7 +1899,7 @@ func TestCreateInstancesBadCase2(t *testing.T) { t.Fatal(err) } defer discoverSuit.Destroy() - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() _, service := discoverSuit.createCommonService(t, 123) defer discoverSuit.cleanServiceName(service.GetName().GetValue(), service.GetNamespace().GetValue()) @@ -2073,7 +2075,7 @@ func TestUpdateInstancesFiled(t *testing.T) { instanceReq.EnableHealthCheck = utils.NewBoolValue(false) So(discoverSuit.DiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*apiservice.Instance{instanceReq}).GetCode().GetValue(), ShouldEqual, api.ExecuteSuccess) // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() newInstanceResp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{ "service": serviceResp.GetName().GetValue(), "namespace": serviceResp.GetNamespace().GetValue(), @@ -2113,7 +2115,7 @@ func TestUpdateInstancesFiled(t *testing.T) { So(discoverSuit.DiscoverServer().UpdateInstances(discoverSuit.DefaultCtx, []*apiservice.Instance{instanceReq}).GetCode().GetValue(), ShouldEqual, api.ExecuteSuccess) // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() newInstanceResp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, map[string]string{ "service": serviceResp.GetName().GetValue(), "namespace": serviceResp.GetNamespace().GetValue(), @@ -2130,7 +2132,7 @@ func (d *DiscoverTestSuit) getInstancesWithService(t *testing.T, name string, na "namespace": namespace, } // 强制先update一次,规避上一次的数据查询结果 - _ = d.DiscoverServer().Cache().TestUpdate() + _ = d.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := d.DiscoverServer().GetInstances(d.DefaultCtx, query) if !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) @@ -2280,7 +2282,7 @@ func TestCheckInstanceParam(t *testing.T) { defer discoverSuit.cleanInstance(instanceResp.GetId().GetValue()) // 强制先update一次,规避上一次的数据查询结果 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() t.Run("都不传", func(t *testing.T) { resp := discoverSuit.DiscoverServer().GetInstances(discoverSuit.DefaultCtx, make(map[string]string)) @@ -2504,7 +2506,7 @@ func Test_HealthCheckInstanceMetadata(t *testing.T) { err := future.Wait() assert.NoError(t, err) - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() ins1Cache := discoverSuit.DiscoverServer().Cache().Instance().GetInstance(ins1.GetId().GetValue()) assert.NotNil(t, ins1Cache, "ins1Cache is nil") val, ok := ins1Cache.Metadata()[model.MetadataInstanceLastHeartbeatTime] @@ -2518,7 +2520,7 @@ func Test_HealthCheckInstanceMetadata(t *testing.T) { err := future.Wait() assert.NoError(t, err) - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() ins1Cache := discoverSuit.DiscoverServer().Cache().Instance().GetInstance(ins1.GetId().GetValue()) assert.NotNil(t, ins1Cache, "ins1Cache is nil") _, ok := ins1Cache.Metadata()[model.MetadataInstanceLastHeartbeatTime] @@ -2565,7 +2567,7 @@ func Test_OperateInstanceMetadata(t *testing.T) { }) assert.NoError(t, err) - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() ins1Cache := discoverSuit.DiscoverServer().Cache().Instance().GetInstance(ins1.GetId().GetValue()) assert.NotNil(t, ins1Cache, "ins1Cache is nil") val, ok := ins1Cache.Metadata()["ins1_mock_key"] @@ -2603,7 +2605,7 @@ func Test_OperateInstanceMetadata(t *testing.T) { }) assert.NoError(t, err) - discoverSuit.DiscoverServer().Cache().TestUpdate() + discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() ins1Cache := discoverSuit.DiscoverServer().Cache().Instance().GetInstance(ins1.GetId().GetValue()) assert.NotNil(t, ins1Cache, "ins1Cache is nil") _, ok := ins1Cache.Metadata()["ins1_mock_key"] diff --git a/service/interceptor/auth/server_authability.go b/service/interceptor/auth/server_authability.go index 8c0de800a..b70909509 100644 --- a/service/interceptor/auth/server_authability.go +++ b/service/interceptor/auth/server_authability.go @@ -29,7 +29,7 @@ import ( "go.uber.org/zap" "github.com/polarismesh/polaris/auth" - "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/service" @@ -60,7 +60,7 @@ func NewServerAuthAbility(nextSvr service.DiscoverServer, } // Cache Get cache management -func (svr *ServerAuthAbility) Cache() *cache.CacheManager { +func (svr *ServerAuthAbility) Cache() cachetypes.CacheManager { return svr.nextSvr.Cache() } diff --git a/service/interceptor/paramcheck/check.go b/service/interceptor/paramcheck/check.go deleted file mode 100644 index 90a6e39a7..000000000 --- a/service/interceptor/paramcheck/check.go +++ /dev/null @@ -1,406 +0,0 @@ -/** - * Tencent is pleased to support the open source community by making Polaris available. - * - * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the BSD 3-Clause License (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://opensource.org/licenses/BSD-3-Clause - * - * Unless required by applicable law or agreed to in writing, software distributed - * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR - * CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package paramcheck - -import ( - "context" - - "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" - "github.com/polarismesh/specification/source/go/api/v1/service_manage" - "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" - - "github.com/polarismesh/polaris/common/api/l5" - "github.com/polarismesh/polaris/common/model" -) - -// AppendServiceContractInterfaces implements service.DiscoverServer. -func (svr *Server) AppendServiceContractInterfaces(ctx context.Context, - contract *service_manage.ServiceContract, - source service_manage.InterfaceDescriptor_Source) *service_manage.Response { - return svr.nextSvr.AppendServiceContractInterfaces(ctx, contract, source) -} - -// CreateCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) CreateCircuitBreakerRules(ctx context.Context, - request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateCircuitBreakerRules(ctx, request) -} - -// CreateCircuitBreakerVersions implements service.DiscoverServer. -func (svr *Server) CreateCircuitBreakerVersions(ctx context.Context, - req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateCircuitBreakerVersions(ctx, req) -} - -// CreateCircuitBreakers implements service.DiscoverServer. -func (svr *Server) CreateCircuitBreakers(ctx context.Context, - req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateCircuitBreakers(ctx, req) -} - -// CreateFaultDetectRules implements service.DiscoverServer. -func (svr *Server) CreateFaultDetectRules(ctx context.Context, - request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateFaultDetectRules(ctx, request) -} - -// CreateInstances implements service.DiscoverServer. -func (svr *Server) CreateInstances(ctx context.Context, - reqs []*service_manage.Instance) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateInstances(ctx, reqs) -} - -// CreateRateLimits implements service.DiscoverServer. -func (svr *Server) CreateRateLimits(ctx context.Context, - request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateRateLimits(ctx, request) -} - -// CreateRoutingConfigs implements service.DiscoverServer. -func (svr *Server) CreateRoutingConfigs(ctx context.Context, - req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateRoutingConfigs(ctx, req) -} - -// CreateRoutingConfigsV2 implements service.DiscoverServer. -func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, - req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateRoutingConfigsV2(ctx, req) -} - -// CreateServiceAlias implements service.DiscoverServer. -func (svr *Server) CreateServiceAlias(ctx context.Context, - req *service_manage.ServiceAlias) *service_manage.Response { - return svr.nextSvr.CreateServiceAlias(ctx, req) -} - -// CreateServiceContractInterfaces implements service.DiscoverServer. -func (svr *Server) CreateServiceContractInterfaces(ctx context.Context, - contract *service_manage.ServiceContract, source service_manage.InterfaceDescriptor_Source) *service_manage.Response { - return svr.nextSvr.CreateServiceContractInterfaces(ctx, contract, source) -} - -// CreateServiceContracts implements service.DiscoverServer. -func (svr *Server) CreateServiceContracts(ctx context.Context, - req []*service_manage.ServiceContract) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateServiceContracts(ctx, req) -} - -// CreateServices implements service.DiscoverServer. -func (svr *Server) CreateServices(ctx context.Context, - req []*service_manage.Service) *service_manage.BatchWriteResponse { - return svr.nextSvr.CreateServices(ctx, req) -} - -// DeleteCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) DeleteCircuitBreakerRules(ctx context.Context, - request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteCircuitBreakerRules(ctx, request) -} - -// DeleteCircuitBreakers implements service.DiscoverServer. -func (svr *Server) DeleteCircuitBreakers(ctx context.Context, - req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteCircuitBreakers(ctx, req) -} - -// DeleteFaultDetectRules implements service.DiscoverServer. -func (svr *Server) DeleteFaultDetectRules(ctx context.Context, - request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteFaultDetectRules(ctx, request) -} - -// DeleteInstances implements service.DiscoverServer. -func (svr *Server) DeleteInstances(ctx context.Context, - req []*service_manage.Instance) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteInstances(ctx, req) -} - -// DeleteInstancesByHost implements service.DiscoverServer. -func (svr *Server) DeleteInstancesByHost(ctx context.Context, - req []*service_manage.Instance) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteInstancesByHost(ctx, req) -} - -// DeleteRateLimits implements service.DiscoverServer. -func (svr *Server) DeleteRateLimits(ctx context.Context, - request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteRateLimits(ctx, request) -} - -// DeleteRoutingConfigs implements service.DiscoverServer. -func (svr *Server) DeleteRoutingConfigs(ctx context.Context, - req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteRoutingConfigs(ctx, req) -} - -// DeleteRoutingConfigsV2 implements service.DiscoverServer. -func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, - req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteRoutingConfigsV2(ctx, req) -} - -// DeleteServiceAliases implements service.DiscoverServer. -func (svr *Server) DeleteServiceAliases(ctx context.Context, - req []*service_manage.ServiceAlias) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteServiceAliases(ctx, req) -} - -// DeleteServiceContractInterfaces implements service.DiscoverServer. -func (svr *Server) DeleteServiceContractInterfaces(ctx context.Context, - contract *service_manage.ServiceContract) *service_manage.Response { - return svr.nextSvr.DeleteServiceContractInterfaces(ctx, contract) -} - -// DeleteServiceContracts implements service.DiscoverServer. -func (svr *Server) DeleteServiceContracts(ctx context.Context, - req []*service_manage.ServiceContract) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteServiceContracts(ctx, req) -} - -// DeleteServices implements service.DiscoverServer. -func (svr *Server) DeleteServices(ctx context.Context, - req []*service_manage.Service) *service_manage.BatchWriteResponse { - return svr.nextSvr.DeleteServices(ctx, req) -} - -// EnableCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) EnableCircuitBreakerRules(ctx context.Context, - request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.EnableCircuitBreakerRules(ctx, request) -} - -// EnableRateLimits implements service.DiscoverServer. -func (svr *Server) EnableRateLimits(ctx context.Context, - request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - return svr.nextSvr.EnableRateLimits(ctx, request) -} - -// EnableRoutings implements service.DiscoverServer. -func (svr *Server) EnableRoutings(ctx context.Context, - req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.EnableRoutings(ctx, req) -} - -// GetAllServices implements service.DiscoverServer. -func (svr *Server) GetAllServices(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetAllServices(ctx, query) -} - -// GetCircuitBreaker implements service.DiscoverServer. -func (svr *Server) GetCircuitBreaker(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetCircuitBreaker(ctx, query) -} - -// GetCircuitBreakerByService implements service.DiscoverServer. -func (svr *Server) GetCircuitBreakerByService(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetCircuitBreakerByService(ctx, query) -} - -// GetCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) GetCircuitBreakerRules(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetCircuitBreakerRules(ctx, query) -} - -// GetCircuitBreakerToken implements service.DiscoverServer. -func (svr *Server) GetCircuitBreakerToken(ctx context.Context, - req *fault_tolerance.CircuitBreaker) *service_manage.Response { - return svr.nextSvr.GetCircuitBreakerToken(ctx, req) -} - -// GetCircuitBreakerVersions implements service.DiscoverServer. -func (svr *Server) GetCircuitBreakerVersions(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetCircuitBreakerVersions(ctx, query) -} - -// GetFaultDetectRules implements service.DiscoverServer. -func (svr *Server) GetFaultDetectRules(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetFaultDetectRules(ctx, query) -} - -// GetInstanceLabels implements service.DiscoverServer. -func (svr *Server) GetInstanceLabels(ctx context.Context, - query map[string]string) *service_manage.Response { - return svr.nextSvr.GetInstanceLabels(ctx, query) -} - -// GetInstances implements service.DiscoverServer. -func (svr *Server) GetInstances(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetInstances(ctx, query) -} - -// GetInstancesCount implements service.DiscoverServer. -func (svr *Server) GetInstancesCount(ctx context.Context) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetInstancesCount(ctx) -} - -// GetMasterCircuitBreakers implements service.DiscoverServer. -func (svr *Server) GetMasterCircuitBreakers(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetMasterCircuitBreakers(ctx, query) -} - -// GetPrometheusTargets implements service.DiscoverServer. -func (svr *Server) GetPrometheusTargets(ctx context.Context, - query map[string]string) *model.PrometheusDiscoveryResponse { - return svr.nextSvr.GetPrometheusTargets(ctx, query) -} - -// GetRateLimits implements service.DiscoverServer. -func (svr *Server) GetRateLimits(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetRateLimits(ctx, query) -} - -// GetReleaseCircuitBreakers implements service.DiscoverServer. -func (svr *Server) GetReleaseCircuitBreakers(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetReleaseCircuitBreakers(ctx, query) -} - -// GetRoutingConfigs implements service.DiscoverServer. -func (svr *Server) GetRoutingConfigs(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetRoutingConfigs(ctx, query) -} - -// GetServiceAliases implements service.DiscoverServer. -func (svr *Server) GetServiceAliases(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetServiceAliases(ctx, query) -} - -// GetServiceContractVersions implements service.DiscoverServer. -func (svr *Server) GetServiceContractVersions(ctx context.Context, - filter map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetServiceContractVersions(ctx, filter) -} - -// GetServiceContracts implements service.DiscoverServer. -func (svr *Server) GetServiceContracts(ctx context.Context, - query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetServiceContracts(ctx, query) -} - -// GetServiceOwner implements service.DiscoverServer. -func (svr *Server) GetServiceOwner(ctx context.Context, - req []*service_manage.Service) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetServiceOwner(ctx, req) -} - -// GetServiceToken implements service.DiscoverServer. -func (svr *Server) GetServiceToken(ctx context.Context, req *service_manage.Service) *service_manage.Response { - return svr.nextSvr.GetServiceToken(ctx, req) -} - -// GetServices implements service.DiscoverServer. -func (svr *Server) GetServices(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetServices(ctx, query) -} - -// GetServicesCount implements service.DiscoverServer. -func (svr *Server) GetServicesCount(ctx context.Context) *service_manage.BatchQueryResponse { - return svr.nextSvr.GetServicesCount(ctx) -} - -// QueryRoutingConfigsV2 implements service.DiscoverServer. -func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { - return svr.nextSvr.QueryRoutingConfigsV2(ctx, query) -} - -// RegisterByNameCmd implements service.DiscoverServer. -func (svr *Server) RegisterByNameCmd(rbnc *l5.Cl5RegisterByNameCmd) (*l5.Cl5RegisterByNameAckCmd, error) { - return svr.nextSvr.RegisterByNameCmd(rbnc) -} - -// ReleaseCircuitBreakers implements service.DiscoverServer. -func (svr *Server) ReleaseCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { - return svr.nextSvr.ReleaseCircuitBreakers(ctx, req) -} - -// SyncByAgentCmd implements service.DiscoverServer. -func (svr *Server) SyncByAgentCmd(ctx context.Context, sbac *l5.Cl5SyncByAgentCmd) (*l5.Cl5SyncByAgentAckCmd, error) { - return svr.nextSvr.SyncByAgentCmd(ctx, sbac) -} - -// UnBindCircuitBreakers implements service.DiscoverServer. -func (svr *Server) UnBindCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { - return svr.nextSvr.UnBindCircuitBreakers(ctx, req) -} - -// UpdateCircuitBreakerRules implements service.DiscoverServer. -func (svr *Server) UpdateCircuitBreakerRules(ctx context.Context, request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateCircuitBreakerRules(ctx, request) -} - -// UpdateCircuitBreakers implements service.DiscoverServer. -func (svr *Server) UpdateCircuitBreakers(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateCircuitBreakers(ctx, req) -} - -// UpdateFaultDetectRules implements service.DiscoverServer. -func (svr *Server) UpdateFaultDetectRules(ctx context.Context, request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateFaultDetectRules(ctx, request) -} - -// UpdateInstances implements service.DiscoverServer. -func (svr *Server) UpdateInstances(ctx context.Context, req []*service_manage.Instance) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateInstances(ctx, req) -} - -// UpdateInstancesIsolate implements service.DiscoverServer. -func (svr *Server) UpdateInstancesIsolate(ctx context.Context, req []*service_manage.Instance) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateInstancesIsolate(ctx, req) -} - -// UpdateRateLimits implements service.DiscoverServer. -func (svr *Server) UpdateRateLimits(ctx context.Context, request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateRateLimits(ctx, request) -} - -// UpdateRoutingConfigs implements service.DiscoverServer. -func (svr *Server) UpdateRoutingConfigs(ctx context.Context, req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateRoutingConfigs(ctx, req) -} - -// UpdateRoutingConfigsV2 implements service.DiscoverServer. -func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateRoutingConfigsV2(ctx, req) -} - -// UpdateServiceAlias implements service.DiscoverServer. -func (svr *Server) UpdateServiceAlias(ctx context.Context, req *service_manage.ServiceAlias) *service_manage.Response { - return svr.nextSvr.UpdateServiceAlias(ctx, req) -} - -// UpdateServiceToken implements service.DiscoverServer. -func (svr *Server) UpdateServiceToken(ctx context.Context, req *service_manage.Service) *service_manage.Response { - return svr.nextSvr.UpdateServiceToken(ctx, req) -} - -// UpdateServices implements service.DiscoverServer. -func (svr *Server) UpdateServices(ctx context.Context, req []*service_manage.Service) *service_manage.BatchWriteResponse { - return svr.nextSvr.UpdateServices(ctx, req) -} diff --git a/service/interceptor/paramcheck/circuit_breaker.go b/service/interceptor/paramcheck/circuit_breaker.go new file mode 100644 index 000000000..d0c42d79d --- /dev/null +++ b/service/interceptor/paramcheck/circuit_breaker.go @@ -0,0 +1,154 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + + "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + + "github.com/polarismesh/polaris/common/utils" + + apifault "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + + api "github.com/polarismesh/polaris/common/api/v1" +) + +// GetMasterCircuitBreakers implements service.DiscoverServer. +func (svr *Server) GetMasterCircuitBreakers(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetMasterCircuitBreakers(ctx, query) +} + +// GetReleaseCircuitBreakers implements service.DiscoverServer. +func (svr *Server) GetReleaseCircuitBreakers(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetReleaseCircuitBreakers(ctx, query) +} + +// GetCircuitBreaker implements service.DiscoverServer. +func (svr *Server) GetCircuitBreaker(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetCircuitBreaker(ctx, query) +} + +// GetCircuitBreakerByService implements service.DiscoverServer. +func (svr *Server) GetCircuitBreakerByService(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetCircuitBreakerByService(ctx, query) +} + +// DeleteCircuitBreakers implements service.DiscoverServer. +func (svr *Server) DeleteCircuitBreakers(ctx context.Context, + req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { + return svr.nextSvr.DeleteCircuitBreakers(ctx, req) +} + +// GetCircuitBreakerToken implements service.DiscoverServer. +func (svr *Server) GetCircuitBreakerToken(ctx context.Context, + req *fault_tolerance.CircuitBreaker) *service_manage.Response { + return svr.nextSvr.GetCircuitBreakerToken(ctx, req) +} + +// GetCircuitBreakerVersions implements service.DiscoverServer. +func (svr *Server) GetCircuitBreakerVersions(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetCircuitBreakerVersions(ctx, query) +} + +// GetCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) GetCircuitBreakerRules(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetCircuitBreakerRules(ctx, query) +} + +// DeleteCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) DeleteCircuitBreakerRules(ctx context.Context, + request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(request); err != nil { + return err + } + return svr.nextSvr.DeleteCircuitBreakerRules(ctx, request) +} + +// EnableCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) EnableCircuitBreakerRules(ctx context.Context, + request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(request); err != nil { + return err + } + return svr.nextSvr.EnableCircuitBreakerRules(ctx, request) +} + +// ReleaseCircuitBreakers implements service.DiscoverServer. +func (svr *Server) ReleaseCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { + return svr.nextSvr.ReleaseCircuitBreakers(ctx, req) +} + +// UnBindCircuitBreakers implements service.DiscoverServer. +func (svr *Server) UnBindCircuitBreakers(ctx context.Context, req []*service_manage.ConfigRelease) *service_manage.BatchWriteResponse { + return svr.nextSvr.UnBindCircuitBreakers(ctx, req) +} + +// UpdateCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) UpdateCircuitBreakerRules(ctx context.Context, request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(request); err != nil { + return err + } + return svr.nextSvr.UpdateCircuitBreakerRules(ctx, request) +} + +// CreateCircuitBreakerRules implements service.DiscoverServer. +func (svr *Server) CreateCircuitBreakerRules(ctx context.Context, + request []*fault_tolerance.CircuitBreakerRule) *service_manage.BatchWriteResponse { + if err := checkBatchCircuitBreakerRules(request); err != nil { + return err + } + return svr.nextSvr.CreateCircuitBreakerRules(ctx, request) +} + +// CreateCircuitBreakerVersions implements service.DiscoverServer. +func (svr *Server) CreateCircuitBreakerVersions(ctx context.Context, + req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateCircuitBreakerVersions(ctx, req) +} + +// CreateCircuitBreakers implements service.DiscoverServer. +func (svr *Server) CreateCircuitBreakers(ctx context.Context, + req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateCircuitBreakers(ctx, req) +} + +// UpdateCircuitBreakers implements service.DiscoverServer. +func (svr *Server) UpdateCircuitBreakers(ctx context.Context, req []*fault_tolerance.CircuitBreaker) *service_manage.BatchWriteResponse { + return svr.nextSvr.UpdateCircuitBreakers(ctx, req) +} + +func checkBatchCircuitBreakerRules(req []*apifault.CircuitBreakerRule) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + return nil +} diff --git a/service/interceptor/paramcheck/cl5.go b/service/interceptor/paramcheck/cl5.go new file mode 100644 index 000000000..8622c33e3 --- /dev/null +++ b/service/interceptor/paramcheck/cl5.go @@ -0,0 +1,34 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + + "github.com/polarismesh/polaris/common/api/l5" +) + +// RegisterByNameCmd implements service.DiscoverServer. +func (svr *Server) RegisterByNameCmd(rbnc *l5.Cl5RegisterByNameCmd) (*l5.Cl5RegisterByNameAckCmd, error) { + return svr.nextSvr.RegisterByNameCmd(rbnc) +} + +// SyncByAgentCmd implements service.DiscoverServer. +func (svr *Server) SyncByAgentCmd(ctx context.Context, sbac *l5.Cl5SyncByAgentCmd) (*l5.Cl5SyncByAgentAckCmd, error) { + return svr.nextSvr.SyncByAgentCmd(ctx, sbac) +} diff --git a/service/interceptor/paramcheck/client.go b/service/interceptor/paramcheck/client.go index c829f6b6c..83c5d7246 100644 --- a/service/interceptor/paramcheck/client.go +++ b/service/interceptor/paramcheck/client.go @@ -25,17 +25,48 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/model" "github.com/polarismesh/polaris/common/utils" "github.com/polarismesh/polaris/service" ) +var ( + clientFilterAttributes = map[string]struct{}{ + "type": {}, + "host": {}, + "limit": {}, + "offset": {}, + "version": {}, + } +) + +// GetPrometheusTargets implements service.DiscoverServer. +func (svr *Server) GetPrometheusTargets(ctx context.Context, + query map[string]string) *model.PrometheusDiscoveryResponse { + return svr.nextSvr.GetPrometheusTargets(ctx, query) +} + // RegisterInstance create one instance by client func (s *Server) RegisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { + // 参数检查 + if err := checkMetadata(req.GetMetadata()); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidMetadata, req) + } + instanceID, rsp := checkCreateInstance(req) + if rsp != nil { + return rsp + } + req.Id = utils.NewStringValue(instanceID) return s.nextSvr.RegisterInstance(ctx, req) } // DeregisterInstance delete onr instance by client func (s *Server) DeregisterInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { + instanceID, resp := checkReviseInstance(req) + if resp != nil { + return resp + } + req.Id = wrapperspb.String(instanceID) return s.nextSvr.DeregisterInstance(ctx, req) } @@ -133,6 +164,15 @@ func (s *Server) GetLaneRuleWithCache(ctx context.Context, req *apiservice.Servi // UpdateInstance update one instance by client func (s *Server) UpdateInstance(ctx context.Context, req *apiservice.Instance) *apiservice.Response { + // 参数检查 + if err := checkMetadata(req.GetMetadata()); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidMetadata, req) + } + instanceID, rsp := checkReviseInstance(req) + if rsp != nil { + return rsp + } + req.Id = utils.NewStringValue(instanceID) return s.nextSvr.UpdateInstance(ctx, req) } diff --git a/service/interceptor/paramcheck/fault_detect.go b/service/interceptor/paramcheck/fault_detect.go new file mode 100644 index 000000000..a21a073a5 --- /dev/null +++ b/service/interceptor/paramcheck/fault_detect.go @@ -0,0 +1,48 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + + "github.com/polarismesh/specification/source/go/api/v1/fault_tolerance" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" +) + +// DeleteFaultDetectRules implements service.DiscoverServer. +func (svr *Server) DeleteFaultDetectRules(ctx context.Context, + request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.DeleteFaultDetectRules(ctx, request) +} + +// GetFaultDetectRules implements service.DiscoverServer. +func (svr *Server) GetFaultDetectRules(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetFaultDetectRules(ctx, query) +} + +// CreateFaultDetectRules implements service.DiscoverServer. +func (svr *Server) CreateFaultDetectRules(ctx context.Context, + request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateFaultDetectRules(ctx, request) +} + +// UpdateFaultDetectRules implements service.DiscoverServer. +func (svr *Server) UpdateFaultDetectRules(ctx context.Context, request []*fault_tolerance.FaultDetectRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.UpdateFaultDetectRules(ctx, request) +} diff --git a/service/interceptor/paramcheck/instance.go b/service/interceptor/paramcheck/instance.go new file mode 100644 index 000000000..4a300b25a --- /dev/null +++ b/service/interceptor/paramcheck/instance.go @@ -0,0 +1,389 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + "errors" + "strconv" + "strings" + + "github.com/golang/protobuf/ptypes/wrappers" + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "google.golang.org/protobuf/types/known/wrapperspb" + + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" + "github.com/polarismesh/polaris/common/utils" + "github.com/polarismesh/polaris/plugin" +) + +var ( + // InstanceFilterAttributes 查询实例支持的过滤字段 + InstanceFilterAttributes = map[string]bool{ + "id": true, // 实例ID + "service": true, // 服务name + "namespace": true, // 服务namespace + "host": true, + "port": true, + "keys": true, + "values": true, + "protocol": true, + "version": true, + "health_status": true, + "healthy": true, // health_status, healthy都有,以healthy为准 + "isolate": true, + "weight": true, + "logic_set": true, + "cmdb_region": true, + "cmdb_zone": true, + "cmdb_idc": true, + "priority": true, + "offset": true, + "limit": true, + } +) + +// CreateInstances implements service.DiscoverServer. +func (svr *Server) CreateInstances(ctx context.Context, + reqs []*service_manage.Instance) *service_manage.BatchWriteResponse { + if checkError := checkBatchInstance(reqs); checkError != nil { + return checkError + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + req := reqs[i] + instanceID, checkError := checkCreateInstance(req) + if checkError != nil { + api.Collect(batchRsp, checkError) + continue + } + // Restricted Instance frequently registered + if ok := svr.allowInstanceAccess(instanceID); !ok { + log.Error("create instance not allowed to access: exceed ratelimit", + utils.RequestID(ctx), utils.ZapInstanceID(instanceID)) + api.Collect(batchRsp, api.NewInstanceResponse(apimodel.Code_InstanceTooManyRequests, req)) + continue + } + req.Id = wrapperspb.String(instanceID) + reqs[i] = req + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + + return svr.nextSvr.CreateInstances(ctx, reqs) +} + +// DeleteInstances implements service.DiscoverServer. +func (svr *Server) DeleteInstances(ctx context.Context, + reqs []*service_manage.Instance) *service_manage.BatchWriteResponse { + if checkError := checkBatchInstance(reqs); checkError != nil { + return checkError + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + req := reqs[i] + instanceID, checkError := checkReviseInstance(req) + if checkError != nil { + api.Collect(batchRsp, checkError) + continue + } + // Restricted Instance frequently registered + if ok := svr.allowInstanceAccess(instanceID); !ok { + log.Error("delete instance is not allow access", utils.RequestID(ctx)) + api.Collect(batchRsp, api.NewInstanceResponse(apimodel.Code_InstanceTooManyRequests, req)) + continue + } + req.Id = wrapperspb.String(instanceID) + reqs[i] = req + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.DeleteInstances(ctx, reqs) +} + +// DeleteInstancesByHost implements service.DiscoverServer. +func (svr *Server) DeleteInstancesByHost(ctx context.Context, + reqs []*service_manage.Instance) *service_manage.BatchWriteResponse { + if checkError := checkBatchInstance(reqs); checkError != nil { + return checkError + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + // 参数校验 + if err := checkInstanceByHost(reqs[i]); err != nil { + api.Collect(batchRsp, err) + continue + } + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + + return svr.nextSvr.DeleteInstancesByHost(ctx, reqs) +} + +// GetInstanceLabels implements service.DiscoverServer. +func (svr *Server) GetInstanceLabels(ctx context.Context, + query map[string]string) *service_manage.Response { + return svr.nextSvr.GetInstanceLabels(ctx, query) +} + +// GetInstances implements service.DiscoverServer. +func (svr *Server) GetInstances(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + + // 不允许全量查询服务实例 + if len(query) == 0 { + return api.NewBatchQueryResponse(apimodel.Code_EmptyQueryParameter) + } + + var metaFilter map[string]string + metaKey, metaKeyAvail := query["keys"] + metaValue, metaValueAvail := query["values"] + if metaKeyAvail != metaValueAvail { + return api.NewBatchQueryResponseWithMsg( + apimodel.Code_InvalidQueryInsParameter, "instance metadata key and value must be both provided") + } + if metaKeyAvail { + metaFilter = map[string]string{} + keys := strings.Split(metaKey, ",") + values := strings.Split(metaValue, ",") + if len(keys) == len(values) { + for i := range keys { + metaFilter[keys[i]] = values[i] + } + } else { + return api.NewBatchQueryResponseWithMsg( + apimodel.Code_InvalidQueryInsParameter, "instance metadata key and value length are different") + } + } + + // 以healthy为准 + _, lhs := query["health_status"] + _, rhs := query["healthy"] + if lhs && rhs { + delete(query, "health_status") + } + + for key, value := range query { + if _, ok := InstanceFilterAttributes[key]; !ok { + log.Errorf("[Server][Instance][Query] attribute(%s) is not allowed", key) + return api.NewBatchQueryResponseWithMsg( + apimodel.Code_InvalidParameter, key+" is not allowed") + } + + if value == "" { + log.Errorf("[Server][Instance][Query] attribute(%s: %s) is not allowed empty", key, value) + return api.NewBatchQueryResponseWithMsg( + apimodel.Code_InvalidParameter, "the value for "+key+" is empty") + } + } + + offset, limit, err := utils.ParseOffsetAndLimit(query) + if err != nil { + return api.NewBatchQueryResponse(apimodel.Code_InvalidParameter) + } + query["offset"] = strconv.FormatUint(uint64(offset), 10) + query["limit"] = strconv.FormatUint(uint64(limit), 10) + return svr.nextSvr.GetInstances(ctx, query) +} + +// GetInstancesCount implements service.DiscoverServer. +func (svr *Server) GetInstancesCount(ctx context.Context) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetInstancesCount(ctx) +} + +// UpdateInstances implements service.DiscoverServer. +func (svr *Server) UpdateInstances(ctx context.Context, reqs []*service_manage.Instance) *service_manage.BatchWriteResponse { + if checkError := checkBatchInstance(reqs); checkError != nil { + return checkError + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + if err := checkMetadata(reqs[i].GetMetadata()); err != nil { + api.Collect(batchRsp, api.NewInstanceResponse(apimodel.Code_InvalidMetadata, reqs[i])) + continue + } + // 参数检查 + instanceID, checkError := checkReviseInstance(reqs[i]) + if checkError != nil { + api.Collect(batchRsp, checkError) + continue + } + reqs[i].Id = wrapperspb.String(instanceID) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.UpdateInstances(ctx, reqs) +} + +// UpdateInstancesIsolate implements service.DiscoverServer. +func (svr *Server) UpdateInstancesIsolate(ctx context.Context, reqs []*service_manage.Instance) *service_manage.BatchWriteResponse { + if checkError := checkBatchInstance(reqs); checkError != nil { + return checkError + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range reqs { + // 参数校验 + if err := checkInstanceByHost(reqs[i]); err != nil { + api.Collect(batchRsp, err) + continue + } + } + return svr.nextSvr.UpdateInstancesIsolate(ctx, reqs) +} + +/* + * @brief 检查批量请求 + */ +func checkBatchInstance(req []*apiservice.Instance) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + +/* + * @brief 检查创建服务实例请求参数 + */ +func checkCreateInstance(req *apiservice.Instance) (string, *apiservice.Response) { + if req == nil { + return "", api.NewInstanceResponse(apimodel.Code_EmptyRequest, req) + } + + if err := checkMetadata(req.GetMetadata()); err != nil { + return "", api.NewInstanceResponse(apimodel.Code_InvalidMetadata, req) + } + + // 检查字段长度是否大于DB中对应字段长 + err, notOk := CheckDbInstanceFieldLen(req) + if notOk { + return "", err + } + + return utils.CheckInstanceTetrad(req) +} + +/* + * @brief 检查删除/修改服务实例请求参数 + */ +func checkReviseInstance(req *apiservice.Instance) (string, *apiservice.Response) { + if req == nil { + return "", api.NewInstanceResponse(apimodel.Code_EmptyRequest, req) + } + + if req.GetId() != nil { + if req.GetId().GetValue() == "" { + return "", api.NewInstanceResponse(apimodel.Code_InvalidInstanceID, req) + } + return req.GetId().GetValue(), nil + } + + // 检查字段长度是否大于DB中对应字段长 + err, notOk := CheckDbInstanceFieldLen(req) + if notOk { + return "", err + } + + return utils.CheckInstanceTetrad(req) +} + +// CheckDbInstanceFieldLen 检查DB中service表对应的入参字段合法性 +func CheckDbInstanceFieldLen(req *apiservice.Instance) (*apiservice.Response, bool) { + if err := utils.CheckDbStrFieldLen(req.GetService(), utils.MaxDbServiceNameLength); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidServiceName, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidNamespaceName, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetHost(), utils.MaxDbInsHostLength); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidInstanceHost, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetProtocol(), utils.MaxDbInsProtocolLength); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidInstanceProtocol, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetVersion(), utils.MaxDbInsVersionLength); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidInstanceVersion, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetLogicSet(), utils.MaxDbInsLogicSetLength); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidInstanceLogicSet, req), true + } + if err := utils.CheckDbMetaDataFieldLen(req.GetMetadata()); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidMetadata, req), true + } + if req.GetPort().GetValue() > 65535 { + return api.NewInstanceResponse(apimodel.Code_InvalidInstancePort, req), true + } + + if req.GetWeight().GetValue() > 65535 { + return api.NewInstanceResponse(apimodel.Code_InvalidParameter, req), true + } + return nil, false +} + +// 实例访问限流 +func (s *Server) allowInstanceAccess(instanceID string) bool { + if s.ratelimit == nil { + return true + } + + return s.ratelimit.Allow(plugin.InstanceRatelimit, instanceID) +} + +/** + * @brief 根据ip隔离和删除服务实例的参数检查 + */ +func checkInstanceByHost(req *apiservice.Instance) *apiservice.Response { + if req == nil { + return api.NewInstanceResponse(apimodel.Code_EmptyRequest, req) + } + if err := utils.CheckResourceName(req.GetService()); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidServiceName, req) + } + if err := utils.CheckResourceName(req.GetNamespace()); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidNamespaceName, req) + } + if err := checkInstanceHost(req.GetHost()); err != nil { + return api.NewInstanceResponse(apimodel.Code_InvalidInstanceHost, req) + } + return nil +} + +// checkInstanceHost 检查服务实例Host +func checkInstanceHost(host *wrappers.StringValue) error { + if host == nil { + return errors.New(utils.NilErrString) + } + + if host.GetValue() == "" { + return errors.New(utils.EmptyErrString) + } + + return nil +} diff --git a/service/interceptor/paramcheck/ratelimit.go b/service/interceptor/paramcheck/ratelimit.go new file mode 100644 index 000000000..4720b462f --- /dev/null +++ b/service/interceptor/paramcheck/ratelimit.go @@ -0,0 +1,54 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" +) + +// CreateRateLimits implements service.DiscoverServer. +func (svr *Server) CreateRateLimits(ctx context.Context, + request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateRateLimits(ctx, request) +} + +// DeleteRateLimits implements service.DiscoverServer. +func (svr *Server) DeleteRateLimits(ctx context.Context, + request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + return svr.nextSvr.DeleteRateLimits(ctx, request) +} + +// EnableRateLimits implements service.DiscoverServer. +func (svr *Server) EnableRateLimits(ctx context.Context, + request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + return svr.nextSvr.EnableRateLimits(ctx, request) +} + +// GetRateLimits implements service.DiscoverServer. +func (svr *Server) GetRateLimits(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetRateLimits(ctx, query) +} + +// UpdateRateLimits implements service.DiscoverServer. +func (svr *Server) UpdateRateLimits(ctx context.Context, request []*traffic_manage.Rule) *service_manage.BatchWriteResponse { + return svr.nextSvr.UpdateRateLimits(ctx, request) +} diff --git a/service/interceptor/paramcheck/route_rule.go b/service/interceptor/paramcheck/route_rule.go new file mode 100644 index 000000000..4da1fedad --- /dev/null +++ b/service/interceptor/paramcheck/route_rule.go @@ -0,0 +1,76 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + "github.com/polarismesh/specification/source/go/api/v1/traffic_manage" +) + +// UpdateRoutingConfigs implements service.DiscoverServer. +func (svr *Server) UpdateRoutingConfigs(ctx context.Context, req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { + return svr.nextSvr.UpdateRoutingConfigs(ctx, req) +} + +// UpdateRoutingConfigsV2 implements service.DiscoverServer. +func (svr *Server) UpdateRoutingConfigsV2(ctx context.Context, req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.UpdateRoutingConfigsV2(ctx, req) +} + +// QueryRoutingConfigsV2 implements service.DiscoverServer. +func (svr *Server) QueryRoutingConfigsV2(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.QueryRoutingConfigsV2(ctx, query) +} + +// GetRoutingConfigs implements service.DiscoverServer. +func (svr *Server) GetRoutingConfigs(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetRoutingConfigs(ctx, query) +} + +// EnableRoutings implements service.DiscoverServer. +func (svr *Server) EnableRoutings(ctx context.Context, + req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.EnableRoutings(ctx, req) +} + +// CreateRoutingConfigs implements service.DiscoverServer. +func (svr *Server) CreateRoutingConfigs(ctx context.Context, + req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateRoutingConfigs(ctx, req) +} + +// DeleteRoutingConfigs implements service.DiscoverServer. +func (svr *Server) DeleteRoutingConfigs(ctx context.Context, + req []*traffic_manage.Routing) *service_manage.BatchWriteResponse { + return svr.nextSvr.DeleteRoutingConfigs(ctx, req) +} + +// CreateRoutingConfigsV2 implements service.DiscoverServer. +func (svr *Server) CreateRoutingConfigsV2(ctx context.Context, + req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.CreateRoutingConfigsV2(ctx, req) +} + +// DeleteRoutingConfigsV2 implements service.DiscoverServer. +func (svr *Server) DeleteRoutingConfigsV2(ctx context.Context, + req []*traffic_manage.RouteRule) *service_manage.BatchWriteResponse { + return svr.nextSvr.DeleteRoutingConfigsV2(ctx, req) +} diff --git a/service/interceptor/paramcheck/server.go b/service/interceptor/paramcheck/server.go index 83f62ecdd..542f1264b 100644 --- a/service/interceptor/paramcheck/server.go +++ b/service/interceptor/paramcheck/server.go @@ -18,8 +18,10 @@ package paramcheck import ( - "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" + "github.com/polarismesh/polaris/common/log" "github.com/polarismesh/polaris/common/model" + "github.com/polarismesh/polaris/plugin" "github.com/polarismesh/polaris/service" ) @@ -27,18 +29,24 @@ import ( // // 该层会对请求参数做一些调整,根据具体的请求发起人,设置为数据对应的 owner,不可为为别人进行创建资源 type Server struct { - nextSvr service.DiscoverServer + nextSvr service.DiscoverServer + ratelimit plugin.Ratelimit } func NewServer(nextSvr service.DiscoverServer) service.DiscoverServer { proxy := &Server{ nextSvr: nextSvr, } + // 获取限流插件 + proxy.ratelimit = plugin.GetRatelimit() + if proxy.ratelimit == nil { + log.Warnf("Not found Ratelimit Plugin") + } return proxy } // Cache Get cache management -func (svr *Server) Cache() *cache.CacheManager { +func (svr *Server) Cache() cachetypes.CacheManager { return svr.nextSvr.Cache() } diff --git a/service/interceptor/paramcheck/service.go b/service/interceptor/paramcheck/service.go new file mode 100644 index 000000000..d46394f22 --- /dev/null +++ b/service/interceptor/paramcheck/service.go @@ -0,0 +1,352 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + "errors" + "strings" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/log" + "github.com/polarismesh/polaris/common/utils" +) + +var ( + serviceFilter = 1 // 过滤服务的 + instanceFilter = 2 // 过滤实例的 + serviceMetaFilter = 3 // 过滤service Metadata的 + instanceMetaFilter = 4 // 过滤instance Metadata的 + ServiceFilterAttributes = map[string]int{ + "name": serviceFilter, + "namespace": serviceFilter, + "business": serviceFilter, + "department": serviceFilter, + "cmdb_mod1": serviceFilter, + "cmdb_mod2": serviceFilter, + "cmdb_mod3": serviceFilter, + "owner": serviceFilter, + "offset": serviceFilter, + "limit": serviceFilter, + "platform_id": serviceFilter, + // 只返回存在健康实例的服务列表 + "only_exist_health_instance": serviceFilter, + "host": instanceFilter, + "port": instanceFilter, + "keys": serviceMetaFilter, + "values": serviceMetaFilter, + "instance_keys": instanceMetaFilter, + "instance_values": instanceMetaFilter, + } +) + +// CreateServices implements service.DiscoverServer. +func (svr *Server) CreateServices(ctx context.Context, + req []*service_manage.Service) *service_manage.BatchWriteResponse { + if checkError := checkBatchService(req); checkError != nil { + return checkError + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range req { + rsp := checkCreateService(req[i]) + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.CreateServices(ctx, req) +} + +// DeleteServices implements service.DiscoverServer. +func (svr *Server) DeleteServices(ctx context.Context, + req []*service_manage.Service) *service_manage.BatchWriteResponse { + if checkError := checkBatchService(req); checkError != nil { + return checkError + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range req { + rsp := checkReviseService(req[i]) + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.DeleteServices(ctx, req) +} + +// UpdateServices implements service.DiscoverServer. +func (svr *Server) UpdateServices(ctx context.Context, req []*service_manage.Service) *service_manage.BatchWriteResponse { + if checkError := checkBatchService(req); checkError != nil { + return checkError + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range req { + rsp := checkReviseService(req[i]) + // 待更新的参数检查 + if err := checkMetadata(req[i].GetMetadata()); err != nil { + rsp = api.NewServiceResponse(apimodel.Code_InvalidMetadata, req[i]) + } + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.UpdateServices(ctx, req) +} + +// GetAllServices implements service.DiscoverServer. +func (svr *Server) GetAllServices(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetAllServices(ctx, query) +} + +// GetServiceOwner implements service.DiscoverServer. +func (svr *Server) GetServiceOwner(ctx context.Context, + req []*service_manage.Service) *service_manage.BatchQueryResponse { + if err := checkBatchReadService(req); err != nil { + return err + } + return svr.nextSvr.GetServiceOwner(ctx, req) +} + +// GetServiceToken implements service.DiscoverServer. +func (svr *Server) GetServiceToken(ctx context.Context, req *service_manage.Service) *service_manage.Response { + // 校验参数合法性 + if resp := checkReviseService(req); resp != nil { + return resp + } + return svr.nextSvr.GetServiceToken(ctx, req) +} + +// GetServices implements service.DiscoverServer. +func (svr *Server) GetServices(ctx context.Context, query map[string]string) *service_manage.BatchQueryResponse { + var ( + inputInstMetaKeys, inputInstMetaValues string + ) + for key, value := range query { + typ, ok := ServiceFilterAttributes[key] + if !ok { + log.Errorf("[Server][Service][Query] attribute(%s) it not allowed", key) + return api.NewBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, key+" is not allowed") + } + // 元数据value允许为空 + if key != "values" && value == "" { + log.Errorf("[Server][Service][Query] attribute(%s: %s) is not allowed empty", key, value) + return api.NewBatchQueryResponseWithMsg( + apimodel.Code_InvalidParameter, "the value for "+key+" is empty") + } + switch { + case typ == instanceMetaFilter: + if key == "instance_keys" { + inputInstMetaKeys = value + } else { + inputInstMetaValues = value + } + } + } + + if inputInstMetaKeys != "" { + instMetaKeys := strings.Split(inputInstMetaKeys, ",") + instMetaValues := strings.Split(inputInstMetaValues, ",") + if len(instMetaKeys) != len(instMetaValues) { + log.Errorf("[Server][Service][Query] length of instance meta %s and %s should be equal", + inputInstMetaKeys, inputInstMetaValues) + return api.NewBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, + " length of instance_keys and instance_values are not equal") + } + } + + return svr.nextSvr.GetServices(ctx, query) +} + +// GetServicesCount implements service.DiscoverServer. +func (svr *Server) GetServicesCount(ctx context.Context) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetServicesCount(ctx) +} + +// UpdateServiceToken implements service.DiscoverServer. +func (svr *Server) UpdateServiceToken(ctx context.Context, req *service_manage.Service) *service_manage.Response { + // 校验参数合法性 + if resp := checkReviseService(req); resp != nil { + return resp + } + return svr.nextSvr.UpdateServiceToken(ctx, req) +} + +// checkBatchService检查批量请求 +func checkBatchService(req []*apiservice.Service) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + +// checkBatchReadService 检查批量读请求 +func checkBatchReadService(req []*apiservice.Service) *apiservice.BatchQueryResponse { + if len(req) == 0 { + return api.NewBatchQueryResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchQueryResponse(apimodel.Code_BatchSizeOverLimit) + } + + return nil +} + +// checkCreateService 检查创建服务请求参数 +func checkCreateService(req *apiservice.Service) *apiservice.Response { + if req == nil { + return api.NewServiceResponse(apimodel.Code_EmptyRequest, req) + } + + if err := utils.CheckResourceName(req.GetName()); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceName, req) + } + + if err := utils.CheckResourceName(req.GetNamespace()); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidNamespaceName, req) + } + + if err := checkMetadata(req.GetMetadata()); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidMetadata, req) + } + + // 检查字段长度是否大于DB中对应字段长 + err, notOk := CheckDbServiceFieldLen(req) + if notOk { + return err + } + + return nil +} + +// checkReviseService 检查删除/修改/服务token的服务请求参数 +func checkReviseService(req *apiservice.Service) *apiservice.Response { + if req == nil { + return api.NewServiceResponse(apimodel.Code_EmptyRequest, req) + } + + if err := utils.CheckResourceName(req.GetName()); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceName, req) + } + + if err := utils.CheckResourceName(req.GetNamespace()); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidNamespaceName, req) + } + + // 检查字段长度是否大于DB中对应字段长 + err, notOk := CheckDbServiceFieldLen(req) + if notOk { + return err + } + + return nil +} + +// CheckDbServiceFieldLen 检查DB中service表对应的入参字段合法性 +func CheckDbServiceFieldLen(req *apiservice.Service) (*apiservice.Response, bool) { + if err := utils.CheckDbStrFieldLen(req.GetName(), utils.MaxDbServiceNameLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceName, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidNamespaceName, req), true + } + if err := utils.CheckDbMetaDataFieldLen(req.GetMetadata()); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidMetadata, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetPorts(), utils.MaxDbServicePortsLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServicePorts, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetBusiness(), utils.MaxDbServiceBusinessLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceBusiness, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetDepartment(), utils.MaxDbServiceDeptLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceDepartment, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetCmdbMod1(), utils.MaxDbServiceCMDBLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceCMDB, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetCmdbMod2(), utils.MaxDbServiceCMDBLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceCMDB, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetCmdbMod3(), utils.MaxDbServiceCMDBLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceCMDB, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetComment(), utils.MaxDbServiceCommentLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceComment, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetOwners(), utils.MaxDbServiceOwnerLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceOwners, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetToken(), utils.MaxDbServiceToken); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidServiceToken, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetPlatformId(), utils.MaxPlatformIDLength); err != nil { + return api.NewServiceResponse(apimodel.Code_InvalidPlatformID, req), true + } + return nil, false +} + +// checkMetadata 检查metadata的个数; 最大是64个 +// key/value是否符合要求 +func checkMetadata(meta map[string]string) error { + if meta == nil { + return nil + } + + if len(meta) > utils.MaxMetadataLength { + return errors.New("metadata is too long") + } + + /*regStr := "^[0-9A-Za-z-._*]+$" + matchFunc := func(str string) error { + if str == "" { + return nil + } + ok, err := regexp.MatchString(regStr, str) + if err != nil { + log.Errorf("regexp match string(%s) err: %s", str, err.Error()) + return err + } + if !ok { + log.Errorf("metadata string(%s) contains invalid character", str) + return errors.New("contain invalid character") + } + return nil + } + for key, value := range meta { + if err := matchFunc(key); err != nil { + return err + } + if err := matchFunc(value); err != nil { + return err + } + }*/ + + return nil +} diff --git a/service/interceptor/paramcheck/service_alias.go b/service/interceptor/paramcheck/service_alias.go new file mode 100644 index 000000000..68c946044 --- /dev/null +++ b/service/interceptor/paramcheck/service_alias.go @@ -0,0 +1,180 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/utils" +) + +// CreateServiceAlias implements service.DiscoverServer. +func (svr *Server) CreateServiceAlias(ctx context.Context, + req *service_manage.ServiceAlias) *service_manage.Response { + if resp := checkCreateServiceAliasReq(ctx, req); resp != nil { + return resp + } + return svr.nextSvr.CreateServiceAlias(ctx, req) +} + +// DeleteServiceAliases implements service.DiscoverServer. +func (svr *Server) DeleteServiceAliases(ctx context.Context, + req []*service_manage.ServiceAlias) *service_manage.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range req { + rsp := checkDeleteServiceAliasReq(ctx, req[i]) + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + return svr.nextSvr.DeleteServiceAliases(ctx, req) +} + +// UpdateServiceAlias implements service.DiscoverServer. +func (svr *Server) UpdateServiceAlias(ctx context.Context, req *service_manage.ServiceAlias) *service_manage.Response { + // 检查请求参数 + if resp := checkReviseServiceAliasReq(ctx, req); resp != nil { + return resp + } + return svr.nextSvr.UpdateServiceAlias(ctx, req) +} + +// GetServiceAliases implements service.DiscoverServer. +func (svr *Server) GetServiceAliases(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetServiceAliases(ctx, query) +} + +// checkCreateServiceAliasReq 检查别名请求 +func checkCreateServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { + response, done := preCheckAlias(req) + if done { + return response + } + // 检查字段长度是否大于DB中对应字段长 + err, notOk := CheckDbServiceAliasFieldLen(req) + if notOk { + return err + } + return nil +} + +// checkReviseServiceAliasReq 检查删除、修改别名请求 +func checkReviseServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { + resp := checkDeleteServiceAliasReq(ctx, req) + if resp != nil { + return resp + } + // 检查服务名 + if err := utils.CheckResourceName(req.GetService()); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceName, req) + } + + // 检查命名空间 + if err := utils.CheckResourceName(req.GetNamespace()); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceName, req) + } + return nil +} + +// checkDeleteServiceAliasReq 检查删除、修改别名请求 +func checkDeleteServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { + if req == nil { + return api.NewServiceAliasResponse(apimodel.Code_EmptyRequest, req) + } + + // 检查服务别名 + if err := utils.CheckResourceName(req.GetAlias()); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAlias, req) + } + + // 检查服务别名命名空间 + if err := utils.CheckResourceName(req.GetAliasNamespace()); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceWithAlias, req) + } + + // 检查字段长度是否大于DB中对应字段长 + err, notOk := CheckDbServiceAliasFieldLen(req) + if notOk { + return err + } + + return nil +} +func preCheckAlias(req *apiservice.ServiceAlias) (*apiservice.Response, bool) { + if req == nil { + return api.NewServiceAliasResponse(apimodel.Code_EmptyRequest, req), true + } + + if err := utils.CheckResourceName(req.GetService()); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceName, req), true + } + + if err := utils.CheckResourceName(req.GetNamespace()); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceName, req), true + } + + if err := utils.CheckResourceName(req.GetAliasNamespace()); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceName, req), true + } + + // 默认类型,需要检查alias是否为空 + if req.GetType() == apiservice.AliasType_DEFAULT { + if err := utils.CheckResourceName(req.GetAlias()); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAlias, req), true + } + } + return nil, false +} + +// CheckDbServiceAliasFieldLen 检查DB中service表对应的入参字段合法性 +func CheckDbServiceAliasFieldLen(req *apiservice.ServiceAlias) (*apiservice.Response, bool) { + if err := utils.CheckDbStrFieldLen(req.GetService(), utils.MaxDbServiceNameLength); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceName, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceName, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetAlias(), utils.MaxDbServiceNameLength); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAlias, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetAliasNamespace(), utils.MaxDbServiceNamespaceLength); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceWithAlias, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetComment(), utils.MaxDbServiceCommentLength); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAliasComment, req), true + } + if err := utils.CheckDbStrFieldLen(req.GetOwners(), utils.MaxDbServiceOwnerLength); err != nil { + return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAliasOwners, req), true + } + return nil, false +} diff --git a/service/interceptor/paramcheck/service_contract.go b/service/interceptor/paramcheck/service_contract.go new file mode 100644 index 000000000..de2b9c51a --- /dev/null +++ b/service/interceptor/paramcheck/service_contract.go @@ -0,0 +1,135 @@ +/** + * Tencent is pleased to support the open source community by making Polaris available. + * + * Copyright (C) 2019 THL A29 Limited, a Tencent company. All rights reserved. + * + * Licensed under the BSD 3-Clause License (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://opensource.org/licenses/BSD-3-Clause + * + * Unless required by applicable law or agreed to in writing, software distributed + * under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package paramcheck + +import ( + "context" + + apimodel "github.com/polarismesh/specification/source/go/api/v1/model" + "github.com/polarismesh/specification/source/go/api/v1/service_manage" + apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" + + api "github.com/polarismesh/polaris/common/api/v1" + "github.com/polarismesh/polaris/common/utils" +) + +// CreateServiceContracts implements service.DiscoverServer. +func (svr *Server) CreateServiceContracts(ctx context.Context, + req []*service_manage.ServiceContract) *service_manage.BatchWriteResponse { + if rsp := checkBatchContractRules(req); rsp != nil { + return rsp + } + batchRsp := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) + for i := range req { + rsp := checkBaseServiceContract(req[i]) + api.Collect(batchRsp, rsp) + } + if !api.IsSuccess(batchRsp) { + return batchRsp + } + + return svr.nextSvr.CreateServiceContracts(ctx, req) +} + +// DeleteServiceContracts implements service.DiscoverServer. +func (svr *Server) DeleteServiceContracts(ctx context.Context, + req []*service_manage.ServiceContract) *service_manage.BatchWriteResponse { + if rsp := checkBatchContractRules(req); rsp != nil { + return rsp + } + return svr.nextSvr.DeleteServiceContracts(ctx, req) +} + +// GetServiceContractVersions implements service.DiscoverServer. +func (svr *Server) GetServiceContractVersions(ctx context.Context, + filter map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetServiceContractVersions(ctx, filter) +} + +// GetServiceContracts implements service.DiscoverServer. +func (svr *Server) GetServiceContracts(ctx context.Context, + query map[string]string) *service_manage.BatchQueryResponse { + return svr.nextSvr.GetServiceContracts(ctx, query) +} + +// CreateServiceContractInterfaces implements service.DiscoverServer. +func (svr *Server) CreateServiceContractInterfaces(ctx context.Context, + contract *service_manage.ServiceContract, source service_manage.InterfaceDescriptor_Source) *service_manage.Response { + if errRsp := checkOperationServiceContractInterface(contract); errRsp != nil { + return errRsp + } + return svr.nextSvr.CreateServiceContractInterfaces(ctx, contract, source) +} + +// AppendServiceContractInterfaces implements service.DiscoverServer. +func (svr *Server) AppendServiceContractInterfaces(ctx context.Context, + contract *service_manage.ServiceContract, + source service_manage.InterfaceDescriptor_Source) *service_manage.Response { + if errRsp := checkOperationServiceContractInterface(contract); errRsp != nil { + return errRsp + } + return svr.nextSvr.AppendServiceContractInterfaces(ctx, contract, source) +} + +// DeleteServiceContractInterfaces implements service.DiscoverServer. +func (svr *Server) DeleteServiceContractInterfaces(ctx context.Context, + contract *service_manage.ServiceContract) *service_manage.Response { + if errRsp := checkOperationServiceContractInterface(contract); errRsp != nil { + return errRsp + } + return svr.nextSvr.DeleteServiceContractInterfaces(ctx, contract) +} + +func checkBaseServiceContract(req *apiservice.ServiceContract) *apiservice.Response { + if err := utils.CheckResourceName(utils.NewStringValue(req.GetNamespace())); err != nil { + return api.NewResponse(apimodel.Code_InvalidNamespaceName) + } + if req.GetName() == "" { + return api.NewResponseWithMsg(apimodel.Code_BadRequest, "invalid service_contract name") + } + if req.GetProtocol() == "" { + return api.NewResponseWithMsg(apimodel.Code_BadRequest, "invalid service_contract protocol") + } + return nil +} + +func checkOperationServiceContractInterface(contract *apiservice.ServiceContract) *apiservice.Response { + if contract.Id != "" { + return nil + } + if err := checkBaseServiceContract(contract); err != nil { + return err + } + id, errRsp := utils.CheckContractTetrad(contract) + if errRsp != nil { + return errRsp + } + contract.Id = id + return nil +} + +func checkBatchContractRules(req []*service_manage.ServiceContract) *apiservice.BatchWriteResponse { + if len(req) == 0 { + return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) + } + + if len(req) > utils.MaxBatchSize { + return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) + } + return nil +} diff --git a/service/interceptor/register.go b/service/interceptor/register.go index 9c301cbda..399ba8294 100644 --- a/service/interceptor/register.go +++ b/service/interceptor/register.go @@ -25,14 +25,14 @@ import ( ) func init() { - err := service.RegisterServerProxy("paramcheck", func(svr *service.Server, pre service.DiscoverServer) (service.DiscoverServer, error) { - return paramcheck.NewServer(svr), nil + err := service.RegisterServerProxy("paramcheck", func(pre service.DiscoverServer) (service.DiscoverServer, error) { + return paramcheck.NewServer(pre), nil }) if err != nil { panic(err) } - err = service.RegisterServerProxy("auth", func(svr *service.Server, pre service.DiscoverServer) (service.DiscoverServer, error) { + err = service.RegisterServerProxy("auth", func(pre service.DiscoverServer) (service.DiscoverServer, error) { userMgn, err := auth.GetUserServer() if err != nil { return nil, err @@ -42,7 +42,7 @@ func init() { return nil, err } - return service_auth.NewServerAuthAbility(svr, userMgn, strategyMgn), nil + return service_auth.NewServerAuthAbility(pre, userMgn, strategyMgn), nil }) if err != nil { panic(err) diff --git a/service/namespace_test.go b/service/namespace_test.go index 3b35f585e..6451492ee 100644 --- a/service/namespace_test.go +++ b/service/namespace_test.go @@ -139,7 +139,7 @@ func TestRemoveNamespace(t *testing.T) { reqs = append(reqs, resp) } - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() discoverSuit.removeCommonNamespaces(t, reqs) t.Logf("pass") }) @@ -178,7 +178,7 @@ func TestUpdateNamespace(t *testing.T) { req, resp := discoverSuit.createCommonNamespace(t, 200) defer discoverSuit.cleanNamespace(resp.GetName().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() req.Token = resp.Token req.Comment = utils.NewStringValue("new-comment") diff --git a/service/options.go b/service/options.go index d3b0f0fed..3435c4e28 100644 --- a/service/options.go +++ b/service/options.go @@ -87,6 +87,9 @@ var ( { Name: cachetypes.FaultDetectRuleName, }, + { + Name: cachetypes.LaneRuleName, + }, } ) @@ -110,13 +113,17 @@ func WithStorage(storage store.Store) InitOption { } } -func WithCacheManager(cacheOpt *cache.Config, c *cache.CacheManager) InitOption { +func WithCacheManager(cacheOpt *cache.Config, c cachetypes.CacheManager, entries ...cachetypes.ConfigEntry) InitOption { return func(s *Server) { log.Infof("[Naming][Server] cache is open, can access the client api function") - _ = c.OpenResourceCache(namingCacheEntries...) - _ = c.OpenResourceCache(governanceCacheEntries...) - if s.isSupportL5() { - _ = c.OpenResourceCache(l5CacheEntry) + if len(entries) != 0 { + _ = c.OpenResourceCache(entries...) + } else { + _ = c.OpenResourceCache(namingCacheEntries...) + _ = c.OpenResourceCache(governanceCacheEntries...) + if s.isSupportL5() { + _ = c.OpenResourceCache(l5CacheEntry) + } } s.caches = c } diff --git a/service/ratelimit_config_test.go b/service/ratelimit_config_test.go index e8f89a04b..0e93d0653 100644 --- a/service/ratelimit_config_test.go +++ b/service/ratelimit_config_test.go @@ -33,6 +33,7 @@ import ( "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/wrapperspb" + "github.com/polarismesh/polaris/cache" api "github.com/polarismesh/polaris/common/api/v1" "github.com/polarismesh/polaris/common/utils" ) @@ -115,7 +116,7 @@ func TestCreateRateLimit(t *testing.T) { defer discoverSuit.cleanRateLimitRevision(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) t.Run("正常创建限流规则", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().Clear() + _ = discoverSuit.CacheMgr().Clear() time.Sleep(5 * time.Second) @@ -123,13 +124,13 @@ func TestCreateRateLimit(t *testing.T) { defer discoverSuit.cleanRateLimit(rateLimitResp.GetId().GetValue()) // 等待缓存更新 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetRateLimitWithCache(context.Background(), serviceResp) checkRateLimit(t, rateLimitReq, resp.GetRateLimit().GetRules()[0]) }) t.Run("创建限流规则,删除,再创建,可以正常创建", func(t *testing.T) { - _ = discoverSuit.DiscoverServer().Cache().Clear() + _ = discoverSuit.CacheMgr().Clear() time.Sleep(5 * time.Second) rateLimitReq, rateLimitResp := discoverSuit.createCommonRateLimit(t, serviceResp, 3) @@ -140,7 +141,7 @@ func TestCreateRateLimit(t *testing.T) { } // 等待缓存更新 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.DiscoverServer().Cache().(*cache.CacheManager).TestUpdate() resp := discoverSuit.DiscoverServer().GetRateLimitWithCache(context.Background(), serviceResp) checkRateLimit(t, rateLimitReq, resp.GetRateLimit().GetRules()[0]) discoverSuit.cleanRateLimit(rateLimitResp.GetId().GetValue()) @@ -413,6 +414,22 @@ func TestUpdateRateLimit(t *testing.T) { t.Run("04-并发更新限流规则时,可以正常更新", func(t *testing.T) { var wg sync.WaitGroup errs := make(chan error) + + lock := &sync.RWMutex{} + waitDelSvcs := []*apiservice.Service{} + waitDelRules := []*apitraffic.Rule{} + + t.Cleanup(func() { + for i := range waitDelSvcs { + serviceResp := waitDelSvcs[i] + discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) + } + for i := range waitDelRules { + rateLimitResp := waitDelRules[i] + discoverSuit.cleanRateLimit(rateLimitResp.GetId().GetValue()) + } + }) + for i := 1; i <= 50; i++ { wg.Add(1) go func(index int) { @@ -423,12 +440,15 @@ func TestUpdateRateLimit(t *testing.T) { updateRateLimitContent(rateLimitResp, index+1) discoverSuit.updateRateLimit(t, rateLimitResp) - t.Cleanup(func() { - discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - discoverSuit.cleanRateLimit(rateLimitResp.GetId().GetValue()) - }) + func() { + lock.Lock() + defer lock.Unlock() + + waitDelSvcs = append(waitDelSvcs, serviceResp) + waitDelRules = append(waitDelRules, rateLimitResp) + }() - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() filters := map[string]string{ "service": serviceResp.GetName().GetValue(), diff --git a/service/routing_config_v1_test.go b/service/routing_config_v1_test.go index 4bb6f6ee0..a38857670 100644 --- a/service/routing_config_v1_test.go +++ b/service/routing_config_v1_test.go @@ -106,7 +106,7 @@ func TestCreateRoutingConfig(t *testing.T) { _, _ = discoverSuit.createCommonRoutingConfig(t, serviceResp, 3, 0) // 对写进去的数据进行查询 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp) defer discoverSuit.cleanCommonRoutingConfig(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) if !respSuccess(out) { @@ -151,7 +151,7 @@ func TestCreateRoutingConfig(t *testing.T) { _, serviceResp := discoverSuit.createCommonService(t, 120) discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() req := &apitraffic.Routing{} req.Service = serviceResp.Name req.Namespace = serviceResp.Namespace @@ -184,7 +184,7 @@ func TestUpdateRoutingConfig(t *testing.T) { assert.True(t, respSuccess(uResp)) // 等缓存层更新 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() // 直接查询存储无法查询到 v1 的路由规则 total, routingsV1, err := discoverSuit.Storage.GetRoutingConfigs(map[string]string{}, 0, 100) @@ -239,7 +239,7 @@ func TestGetRoutingConfigWithCache(t *testing.T) { routingResps = append(routingResps, routingResp) } - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() for i := 0; i < total; i++ { t.Logf("service : name=%s namespace=%s", serviceResps[i].GetName().GetValue(), serviceResps[i].GetNamespace().GetValue()) out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResps[i]) @@ -269,7 +269,7 @@ func TestGetRoutingConfigWithCache(t *testing.T) { t.Fatal(svcResp.Info) } - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() t.Logf("service : name=%s namespace=%s", svcName, namespaceName) out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, &apiservice.Service{ Name: utils.NewStringValue(svcName), @@ -278,7 +278,7 @@ func TestGetRoutingConfigWithCache(t *testing.T) { assert.True(t, len(out.GetRouting().GetOutbounds()) == 0, "inBounds must be zero") - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() enableResp := discoverSuit.DiscoverServer().EnableRoutings(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{ { @@ -291,7 +291,7 @@ func TestGetRoutingConfigWithCache(t *testing.T) { t.Fatal(enableResp.Info) } - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() out = discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, &apiservice.Service{ Name: utils.NewStringValue(svcName), Namespace: utils.NewStringValue(namespaceName), @@ -374,7 +374,7 @@ func TestGetRoutingConfigWithCache(t *testing.T) { t.Fatal(svcResp.Info) } - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() t.Logf("service : name=%s namespace=%s", svcName, namespaceName) out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, &apiservice.Service{ Name: utils.NewStringValue(svcName), @@ -382,7 +382,7 @@ func TestGetRoutingConfigWithCache(t *testing.T) { }) assert.True(t, len(out.GetRouting().GetOutbounds()) == 0, "inBounds must be zero") - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() enableResp := discoverSuit.DiscoverServer().EnableRoutings(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{ { Id: resp[0].Id, @@ -394,13 +394,13 @@ func TestGetRoutingConfigWithCache(t *testing.T) { t.Fatal(enableResp.Info) } - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() out = discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, &apiservice.Service{ Name: utils.NewStringValue(svcName), Namespace: utils.NewStringValue(namespaceName), }) - assert.True(t, len(out.GetRouting().GetOutbounds()) == 1, "inBounds must be one") + assert.True(t, len(out.GetRouting().GetOutbounds()) == 0, "inBounds must be zero") }) t.Run("服务路由数据不改变,传递了路由revision,不返回数据", func(t *testing.T) { @@ -416,7 +416,7 @@ func TestGetRoutingConfigWithCache(t *testing.T) { _, routingResp := discoverSuit.createCommonRoutingConfig(t, serviceResp, 2, 0) defer discoverSuit.cleanCommonRoutingConfig(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() firstResp := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp) checkSameRoutingConfig(t, routingResp, firstResp.GetRouting()) @@ -436,7 +436,7 @@ func TestGetRoutingConfigWithCache(t *testing.T) { _, serviceResp := discoverSuit.createCommonService(t, 10) defer discoverSuit.cleanServiceName(serviceResp.GetName().GetValue(), serviceResp.GetNamespace().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() if resp := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, serviceResp); !respSuccess(resp) { t.Fatalf("error: %s", resp.GetInfo().GetValue()) } diff --git a/service/routing_config_v2_test.go b/service/routing_config_v2_test.go index aea61853e..af000b722 100644 --- a/service/routing_config_v2_test.go +++ b/service/routing_config_v2_test.go @@ -53,7 +53,7 @@ func TestCreateRoutingConfigV2(t *testing.T) { defer discoverSuit.truncateCommonRoutingConfigV2() // 对写进去的数据进行查询 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", "offset": "0", @@ -157,7 +157,7 @@ func TestCompatibleRoutingConfigV2AndV1(t *testing.T) { discoverSuit.truncateCommonRoutingConfigV2() }) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() // 从缓存中查询应该查到 3+3 条 v2 的路由规则 out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", @@ -190,7 +190,7 @@ func TestCompatibleRoutingConfigV2AndV1(t *testing.T) { discoverSuit.truncateCommonRoutingConfigV2() }) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() // 从缓存中查询应该查到 3 条 v2 的路由规则 out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", @@ -245,7 +245,7 @@ func TestCompatibleRoutingConfigV2AndV1(t *testing.T) { discoverSuit.truncateCommonRoutingConfigV2() }) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() // 从缓存中查询应该查到 3+3 条 v2 的路由规则 out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", @@ -275,7 +275,7 @@ func TestCompatibleRoutingConfigV2AndV1(t *testing.T) { assert.NoError(t, err, err) assert.Nil(t, ruleV2, "v2 routing must delete") - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() // 从缓存中查询应该查到 2 条 v2 的路由规则 out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", @@ -307,7 +307,7 @@ func TestCompatibleRoutingConfigV2AndV1(t *testing.T) { discoverSuit.truncateCommonRoutingConfigV2() }) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() // 从缓存中查询应该查到 3+3 条 v2 的路由规则 out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", @@ -340,7 +340,7 @@ func TestCompatibleRoutingConfigV2AndV1(t *testing.T) { assert.NotNil(t, ruleV2, "v2 routing must exist") assert.Equal(t, rulesV2[0].Description, ruleV2.Description) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", "offset": "0", @@ -382,7 +382,7 @@ func TestDeleteRoutingConfigV2(t *testing.T) { namespaceName := fmt.Sprintf("in-source-service-%d", 0) // 删除之后,数据不见 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() out := discoverSuit.DiscoverServer().GetRoutingConfigWithCache(discoverSuit.DefaultCtx, &apiservice.Service{ Name: utils.NewStringValue(serviceName), Namespace: utils.NewStringValue(namespaceName), @@ -408,7 +408,7 @@ func TestUpdateRoutingConfigV2(t *testing.T) { req := discoverSuit.createCommonRoutingConfigV2(t, 1) defer discoverSuit.cleanCommonRoutingConfigV2(req) // 对写进去的数据进行查询 - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() out := discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", "offset": "0", @@ -427,7 +427,7 @@ func TestUpdateRoutingConfigV2(t *testing.T) { routing.Name = updateName discoverSuit.DiscoverServer().UpdateRoutingConfigsV2(discoverSuit.DefaultCtx, []*apitraffic.RouteRule{routing}) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() out = discoverSuit.DiscoverServer().QueryRoutingConfigsV2(discoverSuit.DefaultCtx, map[string]string{ "limit": "100", "offset": "0", diff --git a/service/server.go b/service/server.go index e472e30b5..4ea474c49 100644 --- a/service/server.go +++ b/service/server.go @@ -23,7 +23,7 @@ import ( apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" "golang.org/x/sync/singleflight" - "github.com/polarismesh/polaris/cache" + cachetypes "github.com/polarismesh/polaris/cache/api" cacheservice "github.com/polarismesh/polaris/cache/service" "github.com/polarismesh/polaris/common/eventhub" "github.com/polarismesh/polaris/common/model" @@ -43,14 +43,13 @@ type Server struct { namespaceSvr namespace.NamespaceOperateServer - caches *cache.CacheManager + caches cachetypes.CacheManager bc *batch.Controller healthServer *healthcheck.Server cmdb plugin.CMDB history plugin.History - ratelimit plugin.Ratelimit l5service *l5service @@ -77,13 +76,17 @@ func (s *Server) allowAutoCreate() bool { return *s.config.AutoCreate } +func (s *Server) Store() store.Store { + return s.storage +} + // HealthServer 健康检查Server func (s *Server) HealthServer() *healthcheck.Server { return s.healthServer } // Cache 返回Cache -func (s *Server) Cache() *cache.CacheManager { +func (s *Server) Cache() cachetypes.CacheManager { return s.caches } @@ -154,15 +157,6 @@ func (s *Server) getLocation(host string) *model.Location { return location } -// 实例访问限流 -func (s *Server) allowInstanceAccess(instanceID string) bool { - if s.ratelimit == nil { - return true - } - - return s.ratelimit.Allow(plugin.InstanceRatelimit, instanceID) -} - func (s *Server) afterServiceResource(ctx context.Context, req *apiservice.Service, save *model.Service, remove bool) error { event := &ResourceEvent{ diff --git a/service/service.go b/service/service.go index 4f9195c2d..0330b4895 100644 --- a/service/service.go +++ b/service/service.go @@ -74,10 +74,6 @@ var ( // CreateServices 批量创建服务 func (s *Server) CreateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse { - if checkError := checkBatchService(req); checkError != nil { - return checkError - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, service := range req { response := s.CreateService(ctx, service) @@ -89,13 +85,6 @@ func (s *Server) CreateServices(ctx context.Context, req []*apiservice.Service) // CreateService 创建单个服务 func (s *Server) CreateService(ctx context.Context, req *apiservice.Service) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - // 参数检查 - if checkError := checkCreateService(req); checkError != nil { - return checkError - } - if _, errResp := s.createNamespaceIfAbsent(ctx, req); errResp != nil { return errResp } @@ -106,8 +95,7 @@ func (s *Server) CreateService(ctx context.Context, req *apiservice.Service) *ap // 检查命名空间是否存在 namespace, err := s.storage.GetNamespace(namespaceName) if err != nil { - log.Error("[Service] get namespace fail", - utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID), zap.Error(err)) + log.Error("[Service] get namespace fail", utils.RequestID(ctx), zap.Error(err)) return api.NewServiceResponse(commonstore.StoreCode2APICode(err), req) } if namespace == nil { @@ -117,8 +105,7 @@ func (s *Server) CreateService(ctx context.Context, req *apiservice.Service) *ap // 检查是否存在 service, err := s.storage.GetService(serviceName, namespaceName) if err != nil { - log.Error("[Service] get service fail", - utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID), zap.Error(err)) + log.Error("[Service] get service fail", utils.RequestID(ctx), zap.Error(err)) return api.NewServiceResponse(commonstore.StoreCode2APICode(err), req) } if service != nil { @@ -129,14 +116,13 @@ func (s *Server) CreateService(ctx context.Context, req *apiservice.Service) *ap // 存储层操作 data := s.createServiceModel(req) if err := s.storage.AddService(data); err != nil { - log.Error("[Service] save service fail", - utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID), zap.Error(err)) + log.Error("[Service] save service fail", utils.RequestID(ctx), zap.Error(err)) // 如果在存储层发现资源存在错误,则需要再一次从存储层获取响应的信息,填充响应的 svc_id 信息 if commonstore.StoreCode2APICode(err) == apimodel.Code_ExistedResource { // 检查是否存在 service, err := s.storage.GetService(serviceName, namespaceName) if err != nil { - log.Error("[Service] get service fail", utils.ZapRequestID(requestID), zap.Error(err)) + log.Error("[Service] get service fail", utils.RequestID(ctx), zap.Error(err)) return api.NewServiceResponse(commonstore.StoreCode2APICode(err), req) } if service != nil { @@ -147,9 +133,8 @@ func (s *Server) CreateService(ctx context.Context, req *apiservice.Service) *ap return wrapperServiceStoreResponse(req, err) } - msg := fmt.Sprintf("create service: namespace=%v, name=%v, meta=%+v", - namespaceName, serviceName, req.GetMetadata()) - log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(fmt.Sprintf("create service: namespace=%v, name=%v, meta=%+v", + namespaceName, serviceName, req.GetMetadata()), utils.RequestID(ctx)) s.RecordHistory(ctx, serviceRecordEntry(ctx, req, data, model.OCreate)) out := &apiservice.Service{ @@ -168,10 +153,6 @@ func (s *Server) CreateService(ctx context.Context, req *apiservice.Service) *ap // DeleteServices 批量删除服务 func (s *Server) DeleteServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse { - if checkError := checkBatchService(req); checkError != nil { - return checkError - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, service := range req { response := s.DeleteService(ctx, service) @@ -186,21 +167,13 @@ func (s *Server) DeleteServices(ctx context.Context, req []*apiservice.Service) // 删除操作需要对服务进行加锁操作, // 防止有与服务关联的实例或者配置有新增的操作 func (s *Server) DeleteService(ctx context.Context, req *apiservice.Service) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - - // 参数检查 - if checkError := checkReviseService(req); checkError != nil { - return checkError - } - namespaceName := req.GetNamespace().GetValue() serviceName := req.GetName().GetValue() // 检查是否存在 service, err := s.storage.GetService(serviceName, namespaceName) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewServiceResponse(commonstore.StoreCode2APICode(err), req) } if service == nil { @@ -208,17 +181,17 @@ func (s *Server) DeleteService(ctx context.Context, req *apiservice.Service) *ap } // 判断service下的资源是否已经全部被删除 - if resp := s.isServiceExistedResource(requestID, platformID, service); resp != nil { + if resp := s.isServiceExistedResource(ctx, service); resp != nil { return resp } if err := s.storage.DeleteService(service.ID, serviceName, namespaceName); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperServiceStoreResponse(req, err) } msg := fmt.Sprintf("delete service: namespace=%v, name=%v", namespaceName, serviceName) - log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, serviceRecordEntry(ctx, req, nil, model.ODelete)) if err := s.afterServiceResource(ctx, req, service, true); err != nil { @@ -229,10 +202,6 @@ func (s *Server) DeleteService(ctx context.Context, req *apiservice.Service) *ap // UpdateServices 批量修改服务 func (s *Server) UpdateServices(ctx context.Context, req []*apiservice.Service) *apiservice.BatchWriteResponse { - if checkError := checkBatchService(req); checkError != nil { - return checkError - } - responses := api.NewBatchWriteResponse(apimodel.Code_ExecuteSuccess) for _, service := range req { response := s.UpdateService(ctx, service) @@ -244,13 +213,6 @@ func (s *Server) UpdateServices(ctx context.Context, req []*apiservice.Service) // UpdateService 修改单个服务 func (s *Server) UpdateService(ctx context.Context, req *apiservice.Service) *apiservice.Response { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParsePlatformID(ctx) - // 校验基础参数合法性 - if resp := checkReviseService(req); resp != nil { - return resp - } - // 鉴权 service, _, resp := s.checkServiceAuthority(ctx, req) if resp != nil { @@ -262,7 +224,7 @@ func (s *Server) UpdateService(ctx context.Context, req *apiservice.Service) *ap return api.NewServiceResponse(apimodel.Code_NotAllowAliasUpdate, req) } - log.Info(fmt.Sprintf("old service: %+v", service), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(fmt.Sprintf("old service: %+v", service), utils.RequestID(ctx)) // 修改 err, needUpdate, needUpdateOwner := s.updateServiceAttribute(req, service) @@ -272,7 +234,7 @@ func (s *Server) UpdateService(ctx context.Context, req *apiservice.Service) *ap // 判断是否需要更新 if !needUpdate { log.Info("update service data no change, no need update", - utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID), zap.String("service", req.String())) + utils.RequestID(ctx), zap.String("service", req.String())) if err := s.afterServiceResource(ctx, req, service, false); err != nil { return api.NewServiceResponse(apimodel.Code_ExecuteException, req) } @@ -282,12 +244,12 @@ func (s *Server) UpdateService(ctx context.Context, req *apiservice.Service) *ap // 存储层操作 if err := s.storage.UpdateService(service, needUpdateOwner); err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperServiceStoreResponse(req, err) } msg := fmt.Sprintf("update service: namespace=%v, name=%v", service.Namespace, service.Name) - log.Info(msg, utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Info(msg, utils.RequestID(ctx)) s.RecordHistory(ctx, serviceRecordEntry(ctx, req, service, model.OUpdate)) if err := s.afterServiceResource(ctx, req, service, false); err != nil { @@ -299,11 +261,6 @@ func (s *Server) UpdateService(ctx context.Context, req *apiservice.Service) *ap // UpdateServiceToken 更新服务token func (s *Server) UpdateServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response { - // 校验参数合法性 - if resp := checkReviseService(req); resp != nil { - return resp - } - // 鉴权 service, _, resp := s.checkServiceAuthority(ctx, req) if resp != nil { @@ -312,20 +269,18 @@ func (s *Server) UpdateServiceToken(ctx context.Context, req *apiservice.Service if service.IsAlias() { return api.NewServiceResponse(apimodel.Code_NotAllowAliasUpdate, req) } - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) // 生成一个新的token和revision service.Token = utils.NewUUID() service.Revision = utils.NewUUID() // 更新数据库 if err := s.storage.UpdateServiceToken(service.ID, service.Token, service.Revision); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperServiceStoreResponse(req, err) } log.Info("update service token", zap.String("namespace", service.Namespace), zap.String("name", service.Name), zap.String("service-id", service.ID), - utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + utils.RequestID(ctx)) s.RecordHistory(ctx, serviceRecordEntry(ctx, req, service, model.OUpdateToken)) // 填充新的token返回 @@ -377,17 +332,7 @@ func (s *Server) GetServices(ctx context.Context, query map[string]string) *apis inputInstMetaKeys, inputInstMetaValues string ) for key, value := range query { - typ, ok := ServiceFilterAttributes[key] - if !ok { - log.Errorf("[Server][Service][Query] attribute(%s) it not allowed", key) - return api.NewBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, key+" is not allowed") - } - // 元数据value允许为空 - if key != "values" && value == "" { - log.Errorf("[Server][Service][Query] attribute(%s: %s) is not allowed empty", key, value) - return api.NewBatchQueryResponseWithMsg( - apimodel.Code_InvalidParameter, "the value for "+key+" is empty") - } + typ, _ := ServiceFilterAttributes[key] switch { case typ == serviceFilter: serviceFilters[key] = value @@ -412,12 +357,6 @@ func (s *Server) GetServices(ctx context.Context, query map[string]string) *apis if inputInstMetaKeys != "" { instMetaKeys := strings.Split(inputInstMetaKeys, ",") instMetaValues := strings.Split(inputInstMetaValues, ",") - if len(instMetaKeys) != len(instMetaValues) { - log.Errorf("[Server][Service][Query] length of instance meta %s and %s should be equal", - inputInstMetaKeys, inputInstMetaValues) - return api.NewBatchQueryResponseWithMsg(apimodel.Code_InvalidParameter, - " length of instance_keys and instance_values are not equal") - } for idx, key := range instMetaKeys { instanceMetas[key] = instMetaValues[idx] } @@ -515,11 +454,6 @@ func (s *Server) GetServicesCount(ctx context.Context) *apiservice.BatchQueryRes // GetServiceToken 查询Service的token func (s *Server) GetServiceToken(ctx context.Context, req *apiservice.Service) *apiservice.Response { - // 校验参数合法性 - if resp := checkReviseService(req); resp != nil { - return resp - } - // 鉴权 _, token, resp := s.checkServiceAuthority(ctx, req) if resp != nil { @@ -538,16 +472,9 @@ func (s *Server) GetServiceToken(ctx context.Context, req *apiservice.Service) * // GetServiceOwner 查询服务负责人 func (s *Server) GetServiceOwner(ctx context.Context, req []*apiservice.Service) *apiservice.BatchQueryResponse { - requestID := utils.ParseRequestID(ctx) - platformID := utils.ParseRequestID(ctx) - - if err := checkBatchReadService(req); err != nil { - return err - } - services, err := s.storage.GetServicesBatch(apis2ServicesName(req)) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(requestID), utils.ZapPlatformID(platformID)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewBatchQueryResponseWithMsg(commonstore.StoreCode2APICode(err), err.Error()) } @@ -592,11 +519,6 @@ func (s *Server) createServiceModel(req *apiservice.Service) *model.Service { // updateServiceAttribute 修改服务属性 func (s *Server) updateServiceAttribute( req *apiservice.Service, service *model.Service) (*apiservice.Response, bool, bool) { - // 待更新的参数检查 - if err := checkMetadata(req.GetMetadata()); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidMetadata, req), false, false - } - var ( needUpdate = false needNewRevision = false @@ -737,7 +659,7 @@ func (s *Server) getRateLimitingCountWithService(name string, namespace string) } // isServiceExistedResource 检查服务下的资源存在情况,在删除服务的时候需要用到 -func (s *Server) isServiceExistedResource(rid, pid string, service *model.Service) *apiservice.Response { +func (s *Server) isServiceExistedResource(ctx context.Context, service *model.Service) *apiservice.Response { // 服务别名,不需要判断 if service.IsAlias() { return nil @@ -748,7 +670,7 @@ func (s *Server) isServiceExistedResource(rid, pid string, service *model.Servic } total, err := s.getInstancesCountWithService(service.Name, service.Namespace) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewServiceResponse(commonstore.StoreCode2APICode(err), out) } if total != 0 { @@ -757,7 +679,7 @@ func (s *Server) isServiceExistedResource(rid, pid string, service *model.Servic total, err = s.getServiceAliasCountWithService(service.Name, service.Namespace) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewServiceResponse(commonstore.StoreCode2APICode(err), out) } if total != 0 { @@ -767,7 +689,7 @@ func (s *Server) isServiceExistedResource(rid, pid string, service *model.Servic // TODO will remove until have sync router rule v1 to v2 total, err = s.getRoutingCountWithService(service.ID) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewServiceResponse(commonstore.StoreCode2APICode(err), out) } @@ -781,15 +703,13 @@ func (s *Server) isServiceExistedResource(rid, pid string, service *model.Servic // return service, token, response func (s *Server) checkServiceAuthority(ctx context.Context, req *apiservice.Service) (*model.Service, string, *apiservice.Response) { - rid := utils.ParseRequestID(ctx) - pid := utils.ParsePlatformID(ctx) namespaceName := req.GetNamespace().GetValue() serviceName := req.GetName().GetValue() // 检查是否存在 svc, err := s.storage.GetService(serviceName, namespaceName) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, "", api.NewServiceResponse(commonstore.StoreCode2APICode(err), req) } if svc == nil { @@ -798,7 +718,7 @@ func (s *Server) checkServiceAuthority(ctx context.Context, req *apiservice.Serv if svc.Reference != "" { svc, err = s.storage.GetServiceByID(svc.Reference) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid), utils.ZapPlatformID(pid)) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, "", api.NewServiceResponse(commonstore.StoreCode2APICode(err), req) } if svc == nil { @@ -948,82 +868,6 @@ func serviceMetaNeedUpdate(req *apiservice.Service, service *model.Service) bool return needUpdate } -// checkBatchService检查批量请求 -func checkBatchService(req []*apiservice.Service) *apiservice.BatchWriteResponse { - if len(req) == 0 { - return api.NewBatchWriteResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return api.NewBatchWriteResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - -// checkBatchReadService 检查批量读请求 -func checkBatchReadService(req []*apiservice.Service) *apiservice.BatchQueryResponse { - if len(req) == 0 { - return api.NewBatchQueryResponse(apimodel.Code_EmptyRequest) - } - - if len(req) > MaxBatchSize { - return api.NewBatchQueryResponse(apimodel.Code_BatchSizeOverLimit) - } - - return nil -} - -// checkCreateService 检查创建服务请求参数 -func checkCreateService(req *apiservice.Service) *apiservice.Response { - if req == nil { - return api.NewServiceResponse(apimodel.Code_EmptyRequest, req) - } - - if err := utils.CheckResourceName(req.GetName()); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceName, req) - } - - if err := utils.CheckResourceName(req.GetNamespace()); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidNamespaceName, req) - } - - if err := checkMetadata(req.GetMetadata()); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidMetadata, req) - } - - // 检查字段长度是否大于DB中对应字段长 - err, notOk := CheckDbServiceFieldLen(req) - if notOk { - return err - } - - return nil -} - -// checkReviseService 检查删除/修改/服务token的服务请求参数 -func checkReviseService(req *apiservice.Service) *apiservice.Response { - if req == nil { - return api.NewServiceResponse(apimodel.Code_EmptyRequest, req) - } - - if err := utils.CheckResourceName(req.GetName()); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceName, req) - } - - if err := utils.CheckResourceName(req.GetNamespace()); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidNamespaceName, req) - } - - // 检查字段长度是否大于DB中对应字段长 - err, notOk := CheckDbServiceFieldLen(req) - if notOk { - return err - } - - return nil -} - // wrapperServiceStoreResponse wrapper service error func wrapperServiceStoreResponse(service *apiservice.Service, err error) *apiservice.Response { if err == nil { @@ -1062,47 +906,3 @@ func serviceRecordEntry(ctx context.Context, req *apiservice.Service, md *model. return entry } - -// CheckDbServiceFieldLen 检查DB中service表对应的入参字段合法性 -func CheckDbServiceFieldLen(req *apiservice.Service) (*apiservice.Response, bool) { - if err := utils.CheckDbStrFieldLen(req.GetName(), MaxDbServiceNameLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceName, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidNamespaceName, req), true - } - if err := utils.CheckDbMetaDataFieldLen(req.GetMetadata()); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidMetadata, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetPorts(), MaxDbServicePortsLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServicePorts, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetBusiness(), MaxDbServiceBusinessLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceBusiness, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetDepartment(), MaxDbServiceDeptLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceDepartment, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetCmdbMod1(), MaxDbServiceCMDBLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceCMDB, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetCmdbMod2(), MaxDbServiceCMDBLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceCMDB, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetCmdbMod3(), MaxDbServiceCMDBLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceCMDB, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetComment(), MaxDbServiceCommentLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceComment, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetOwners(), MaxDbServiceOwnerLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceOwners, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetToken(), MaxDbServiceToken); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidServiceToken, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetPlatformId(), MaxPlatformIDLength); err != nil { - return api.NewServiceResponse(apimodel.Code_InvalidPlatformID, req), true - } - return nil, false -} diff --git a/service/service_alias.go b/service/service_alias.go index 34e00a244..e95a6400f 100644 --- a/service/service_alias.go +++ b/service/service_alias.go @@ -49,19 +49,14 @@ var ( // CreateServiceAlias 创建服务别名 func (s *Server) CreateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { - if resp := checkCreateServiceAliasReq(ctx, req); resp != nil { - return resp - } - - rid := utils.ParseRequestID(ctx) tx, err := s.storage.CreateTransaction() if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewServiceAliasResponse(commonstore.StoreCode2APICode(err), req) } defer func() { _ = tx.Commit() }() - service, response, done := s.checkPointServiceAlias(tx, req, rid) + service, response, done := s.checkPointServiceAlias(ctx, tx, req) if done { return response } @@ -71,7 +66,7 @@ func (s *Server) CreateServiceAlias(ctx context.Context, req *apiservice.Service oldAlias, getErr := s.storage.GetService(req.GetAlias().GetValue(), req.GetAliasNamespace().GetValue()) if getErr != nil { - log.Error(getErr.Error(), utils.ZapRequestID(rid)) + log.Error(getErr.Error(), utils.RequestID(ctx)) return api.NewServiceAliasResponse(commonstore.StoreCode2APICode(err), req) } if oldAlias != nil { @@ -85,12 +80,12 @@ func (s *Server) CreateServiceAlias(ctx context.Context, req *apiservice.Service return resp } if err := s.storage.AddService(input); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewServiceAliasResponse(commonstore.StoreCode2APICode(err), req) } log.Info(fmt.Sprintf("create service alias, service(%s, %s), alias(%s, %s)", - req.Service.Value, req.Namespace.Value, input.Name, input.Namespace), utils.ZapRequestID(rid)) + req.Service.Value, req.Namespace.Value, input.Name, input.Namespace), utils.RequestID(ctx)) out := &apiservice.ServiceAlias{ Service: req.Service, Namespace: req.Namespace, @@ -106,12 +101,12 @@ func (s *Server) CreateServiceAlias(ctx context.Context, req *apiservice.Service return api.NewServiceAliasResponse(apimodel.Code_ExecuteSuccess, out) } -func (s *Server) checkPointServiceAlias( - tx store.Transaction, req *apiservice.ServiceAlias, rid string) (*model.Service, *apiservice.Response, bool) { +func (s *Server) checkPointServiceAlias(ctx context.Context, + tx store.Transaction, req *apiservice.ServiceAlias) (*model.Service, *apiservice.Response, bool) { // 检查指向服务是否存在以及是否为别名 service, err := tx.LockService(req.GetService().GetValue(), req.GetNamespace().GetValue()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return nil, api.NewServiceAliasResponse(commonstore.StoreCode2APICode(err), req), true } if service == nil { @@ -129,9 +124,6 @@ func (s *Server) checkPointServiceAlias( // 需要带上源服务name,namespace,token // 另外一种删除别名的方式,是直接调用删除服务的接口,也是可行的 func (s *Server) DeleteServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { - if resp := checkDeleteServiceAliasReq(ctx, req); resp != nil { - return resp - } rid := utils.ParseRequestID(ctx) alias, err := s.storage.GetService(req.GetAlias().GetValue(), req.GetAliasNamespace().GetValue()) @@ -183,22 +175,10 @@ func (s *Server) DeleteServiceAliases( // UpdateServiceAlias 修改服务别名 func (s *Server) UpdateServiceAlias(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { - rid := utils.ParseRequestID(ctx) - - // 检查请求参数 - if resp := checkReviseServiceAliasReq(ctx, req); resp != nil { - return resp - } - - // 检查别名负责人 - // if err := checkResourceOwners(req.GetOwners()); err != nil { - // return api.NewServiceAliasResponse(api.InvalidServiceAliasOwners, req) - // } - // 检查服务别名是否存在 alias, err := s.storage.GetService(req.GetAlias().GetValue(), req.GetAliasNamespace().GetValue()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewServiceAliasResponse(commonstore.StoreCode2APICode(err), req) } if alias == nil { @@ -208,7 +188,7 @@ func (s *Server) UpdateServiceAlias(ctx context.Context, req *apiservice.Service // 检查将要指向的服务是否存在 service, err := s.storage.GetService(req.GetService().GetValue(), req.GetNamespace().GetValue()) if err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return api.NewServiceAliasResponse(commonstore.StoreCode2APICode(err), req) } if service == nil { @@ -226,19 +206,19 @@ func (s *Server) UpdateServiceAlias(ctx context.Context, req *apiservice.Service } if !needUpdate { - log.Info("update service alias data no change, no need update", utils.ZapRequestID(rid), + log.Info("update service alias data no change, no need update", utils.RequestID(ctx), zap.String("service alias", req.String())) return api.NewServiceAliasResponse(apimodel.Code_NoNeedUpdate, req) } // 执行存储层操作 if err := s.storage.UpdateServiceAlias(alias, needUpdateOwner); err != nil { - log.Error(err.Error(), utils.ZapRequestID(rid)) + log.Error(err.Error(), utils.RequestID(ctx)) return wrapperServiceAliasResponse(req, err) } log.Info(fmt.Sprintf("update service alias, service(%s, %s), alias(%s)", - req.GetService().GetValue(), req.GetNamespace().GetValue(), req.GetAlias().GetValue()), utils.ZapRequestID(rid)) + req.GetService().GetValue(), req.GetNamespace().GetValue(), req.GetAlias().GetValue()), utils.RequestID(ctx)) record := &apiservice.Service{Name: req.Alias, Namespace: req.Namespace} s.RecordHistory(ctx, serviceRecordEntry(ctx, record, alias, model.OUpdate)) @@ -292,89 +272,6 @@ func (s *Server) GetServiceAliases(ctx context.Context, query map[string]string) return resp } -// checkCreateServiceAliasReq 检查别名请求 -func checkCreateServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { - response, done := preCheckAlias(req) - if done { - return response - } - // 检查字段长度是否大于DB中对应字段长 - err, notOk := CheckDbServiceAliasFieldLen(req) - if notOk { - return err - } - return nil -} - -func preCheckAlias(req *apiservice.ServiceAlias) (*apiservice.Response, bool) { - if req == nil { - return api.NewServiceAliasResponse(apimodel.Code_EmptyRequest, req), true - } - - if err := utils.CheckResourceName(req.GetService()); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceName, req), true - } - - if err := utils.CheckResourceName(req.GetNamespace()); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceName, req), true - } - - if err := utils.CheckResourceName(req.GetAliasNamespace()); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceName, req), true - } - - // 默认类型,需要检查alias是否为空 - if req.GetType() == apiservice.AliasType_DEFAULT { - if err := utils.CheckResourceName(req.GetAlias()); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAlias, req), true - } - } - return nil, false -} - -// checkReviseServiceAliasReq 检查删除、修改别名请求 -func checkReviseServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { - resp := checkDeleteServiceAliasReq(ctx, req) - if resp != nil { - return resp - } - // 检查服务名 - if err := utils.CheckResourceName(req.GetService()); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceName, req) - } - - // 检查命名空间 - if err := utils.CheckResourceName(req.GetNamespace()); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceName, req) - } - return nil -} - -// checkDeleteServiceAliasReq 检查删除、修改别名请求 -func checkDeleteServiceAliasReq(ctx context.Context, req *apiservice.ServiceAlias) *apiservice.Response { - if req == nil { - return api.NewServiceAliasResponse(apimodel.Code_EmptyRequest, req) - } - - // 检查服务别名 - if err := utils.CheckResourceName(req.GetAlias()); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAlias, req) - } - - // 检查服务别名命名空间 - if err := utils.CheckResourceName(req.GetAliasNamespace()); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceWithAlias, req) - } - - // 检查字段长度是否大于DB中对应字段长 - err, notOk := CheckDbServiceAliasFieldLen(req) - if notOk { - return err - } - - return nil -} - // updateServiceAliasAttribute 修改服务别名属性 func (s *Server) updateServiceAliasAttribute(req *apiservice.ServiceAlias, alias *model.Service, serviceID string) ( *apiservice.Response, bool, bool) { @@ -454,26 +351,3 @@ func wrapperServiceAliasResponse(alias *apiservice.ServiceAlias, err error) *api resp.Alias = alias return resp } - -// CheckDbServiceAliasFieldLen 检查DB中service表对应的入参字段合法性 -func CheckDbServiceAliasFieldLen(req *apiservice.ServiceAlias) (*apiservice.Response, bool) { - if err := utils.CheckDbStrFieldLen(req.GetService(), MaxDbServiceNameLength); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceName, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceName, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetAlias(), MaxDbServiceNameLength); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAlias, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetAliasNamespace(), MaxDbServiceNamespaceLength); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidNamespaceWithAlias, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetComment(), MaxDbServiceCommentLength); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAliasComment, req), true - } - if err := utils.CheckDbStrFieldLen(req.GetOwners(), MaxDbServiceOwnerLength); err != nil { - return api.NewServiceAliasResponse(apimodel.Code_InvalidServiceAliasOwners, req), true - } - return nil, false -} diff --git a/service/service_alias_test.go b/service/service_alias_test.go index 2c69f7c61..4beee56f7 100644 --- a/service/service_alias_test.go +++ b/service/service_alias_test.go @@ -572,7 +572,7 @@ func TestServiceAliasRelated(t *testing.T) { t.Run("实例Discover,别名查询实例,返回源服务的实例信息", func(t *testing.T) { _, instanceResp := discoverSuit.createCommonInstance(t, serviceResp, 123) defer discoverSuit.cleanInstance(instanceResp.GetId().GetValue()) - _ = discoverSuit.DiscoverServer().Cache().TestUpdate() + _ = discoverSuit.CacheMgr().TestUpdate() service := &apiservice.Service{Name: resp.Alias.Alias, Namespace: resp.Alias.Namespace} disResp := discoverSuit.DiscoverServer().ServiceInstancesCache(discoverSuit.DefaultCtx, &apiservice.DiscoverFilter{}, service) assert.True(t, api.IsSuccess(disResp), disResp.GetInfo().GetValue()) diff --git a/service/service_contract.go b/service/service_contract.go index c6cb2267a..a9e9f0eab 100644 --- a/service/service_contract.go +++ b/service/service_contract.go @@ -64,9 +64,6 @@ func (s *Server) CreateServiceContracts(ctx context.Context, } func (s *Server) CreateServiceContract(ctx context.Context, contract *apiservice.ServiceContract) *apiservice.Response { - if errRsp := checkBaseServiceContract(contract); errRsp != nil { - return errRsp - } contractId := contract.GetId() if contractId == "" { tmpId, errRsp := utils.CheckContractTetrad(contract) @@ -293,8 +290,12 @@ func (s *Server) GetServiceContractVersions(ctx context.Context, filter map[stri func (s *Server) CreateServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, source apiservice.InterfaceDescriptor_Source) *apiservice.Response { - if errRsp := checkOperationServiceContractInterface(contract); errRsp != nil { - return errRsp + if contract.Id == "" { + id, errRsp := utils.CheckContractTetrad(contract) + if errRsp != nil { + return errRsp + } + contract.Id = id } createData := &model.EnrichServiceContract{ @@ -335,9 +336,14 @@ func (s *Server) CreateServiceContractInterfaces(ctx context.Context, func (s *Server) AppendServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract, source apiservice.InterfaceDescriptor_Source) *apiservice.Response { - if errRsp := checkOperationServiceContractInterface(contract); errRsp != nil { - return errRsp + if contract.Id == "" { + id, errRsp := utils.CheckContractTetrad(contract) + if errRsp != nil { + return errRsp + } + contract.Id = id } + saveData, err := s.storage.GetServiceContract(contract.Id) if err != nil { log.Error("[Service][Contract] get save service_contract when append interfaces", utils.RequestID(ctx), zap.Error(err)) @@ -385,8 +391,12 @@ func (s *Server) AppendServiceContractInterfaces(ctx context.Context, func (s *Server) DeleteServiceContractInterfaces(ctx context.Context, contract *apiservice.ServiceContract) *apiservice.Response { - if errRsp := checkOperationServiceContractInterface(contract); errRsp != nil { - return errRsp + if contract.Id == "" { + id, errRsp := utils.CheckContractTetrad(contract) + if errRsp != nil { + return errRsp + } + contract.Id = id } saveData, err := s.storage.GetServiceContract(contract.Id) @@ -429,21 +439,6 @@ func (s *Server) DeleteServiceContractInterfaces(ctx context.Context, return api.NewAnyDataResponse(apimodel.Code_ExecuteSuccess, nil) } -func checkOperationServiceContractInterface(contract *apiservice.ServiceContract) *apiservice.Response { - if contract.Id != "" { - return nil - } - if err := checkBaseServiceContract(contract); err != nil { - return err - } - id, errRsp := utils.CheckContractTetrad(contract) - if errRsp != nil { - return errRsp - } - contract.Id = id - return nil -} - // serviceContractRecordEntry 生成服务的记录entry func serviceContractRecordEntry(ctx context.Context, req *apiservice.ServiceContract, data *model.EnrichServiceContract, operationType model.OperationType) *model.RecordEntry { @@ -463,16 +458,3 @@ func serviceContractRecordEntry(ctx context.Context, req *apiservice.ServiceCont return entry } - -func checkBaseServiceContract(req *apiservice.ServiceContract) *apiservice.Response { - if err := utils.CheckResourceName(utils.NewStringValue(req.GetNamespace())); err != nil { - return api.NewResponse(apimodel.Code_InvalidNamespaceName) - } - if req.GetName() == "" { - return api.NewResponseWithMsg(apimodel.Code_BadRequest, "invalid service_contract name") - } - if req.GetProtocol() == "" { - return api.NewResponseWithMsg(apimodel.Code_BadRequest, "invalid service_contract protocol") - } - return nil -} diff --git a/service/test_export.go b/service/test_export.go index 63666f29a..658f0f64c 100644 --- a/service/test_export.go +++ b/service/test_export.go @@ -19,7 +19,6 @@ package service import ( "context" - "fmt" apimodel "github.com/polarismesh/specification/source/go/api/v1/model" apiservice "github.com/polarismesh/specification/source/go/api/v1/service_manage" @@ -66,44 +65,18 @@ func TestInitialize(ctx context.Context, namingOpt *Config, cacheOpt *cache.Conf if len(cacheEntries) != 0 { entrites = cacheEntries } else { - entrites = append(entrites, l5CacheEntry) - entrites = append(entrites, namingCacheEntries...) - entrites = append(entrites, governanceCacheEntries...) + entrites = GetAllCaches() } - _ = cacheMgr.OpenResourceCache(entrites...) - namingServer.healthServer = healthSvr - namingServer.storage = storage - // 注入命名空间管理模块 - namingServer.namespaceSvr = namespaceSvr - // cache模块,可以不开启 - // 对于控制台集群,只访问控制台接口的,可以不开启cache - log.Infof("[Naming][Server] cache is open, can access the client api function") - namingServer.caches = cacheMgr - namingServer.bc = bc - // l5service - namingServer.l5service = &l5service{} - namingServer.createServiceSingle = &singleflight.Group{} - // 插件初始化 - pluginInitialize() - - var proxySvr DiscoverServer - var err error - // 需要返回包装代理的 DiscoverServer - order := namingOpt.Interceptors - for i := range order { - factory, exist := serverProxyFactories[order[i]] - if !exist { - return nil, nil, fmt.Errorf("name(%s) not exist in serverProxyFactories", order[i]) - } - - proxySvr, err = factory(namingServer, proxySvr) - if err != nil { - return nil, nil, err - } - } - - return proxySvr, namingServer, nil + actualSvr, proxySvr, err := InitServer(ctx, namingOpt, + WithBatchController(bc), + WithCacheManager(cacheOpt, cacheMgr, entrites...), + WithHealthCheckSvr(healthSvr), + WithNamespaceSvr(namespaceSvr), + WithStorage(storage), + ) + namingServer = actualSvr + return proxySvr, namingServer, err } // TestSerialCreateInstance . @@ -113,9 +86,9 @@ func (s *Server) TestSerialCreateInstance( return s.serialCreateInstance(ctx, svcId, req, ins) } -// TestCheckCreateInstance . -func TestCheckCreateInstance(req *apiservice.Instance) (string, *apiservice.Response) { - return checkCreateInstance(req) +// TestSetStore . +func (s *Server) TestSetStore(storage store.Store) { + s.storage = storage } // TestIsEmptyLocation . diff --git a/service/utils.go b/service/utils.go index 0289155ad..bf9055197 100644 --- a/service/utils.go +++ b/service/utils.go @@ -18,7 +18,6 @@ package service import ( - "errors" "fmt" "strconv" "strings" @@ -97,58 +96,6 @@ const ( ParamKeyInstanceId = "instanceId" ) -// checkInstanceHost 检查服务实例Host -func checkInstanceHost(host *wrappers.StringValue) error { - if host == nil { - return errors.New(utils.NilErrString) - } - - if host.GetValue() == "" { - return errors.New(utils.EmptyErrString) - } - - return nil -} - -// checkMetadata 检查metadata的个数; 最大是64个 -// key/value是否符合要求 -func checkMetadata(meta map[string]string) error { - if meta == nil { - return nil - } - - if len(meta) > MaxMetadataLength { - return errors.New("metadata is too long") - } - - /*regStr := "^[0-9A-Za-z-._*]+$" - matchFunc := func(str string) error { - if str == "" { - return nil - } - ok, err := regexp.MatchString(regStr, str) - if err != nil { - log.Errorf("regexp match string(%s) err: %s", str, err.Error()) - return err - } - if !ok { - log.Errorf("metadata string(%s) contains invalid character", str) - return errors.New("contain invalid character") - } - return nil - } - for key, value := range meta { - if err := matchFunc(key); err != nil { - return err - } - if err := matchFunc(value); err != nil { - return err - } - }*/ - - return nil -} - // storeError2AnyResponse store code func storeError2AnyResponse(err error, msg proto.Message) *apiservice.Response { if err == nil {