From 2f2be8afddbd3261e58f20c42aed75535b644d47 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 3 Oct 2022 09:35:24 -0700 Subject: [PATCH] Project settings (#480) Signed-off-by: Yee Hing Tong Signed-off-by: pmahindrakar-oss --- flyteadmin_config.yaml | 6 +- pkg/manager/impl/execution_manager.go | 89 ++--- pkg/manager/impl/execution_manager_test.go | 85 +++-- .../impl/resources/resource_manager.go | 159 ++++++++- .../impl/resources/resource_manager_test.go | 310 +++++++++++++++++- pkg/manager/impl/shared/iface.go | 25 ++ pkg/manager/impl/testutils/attributes.go | 15 + pkg/manager/impl/util/shared.go | 39 +++ pkg/manager/impl/util/shared_test.go | 137 ++++++++ .../impl/validation/attributes_validator.go | 15 + .../impl/validation/project_validator.go | 30 ++ .../impl/validation/project_validator_test.go | 60 ++++ pkg/manager/interfaces/resource.go | 9 +- pkg/manager/mocks/resource.go | 46 +++ pkg/repositories/gormimpl/resource_repo.go | 41 ++- .../gormimpl/resource_repo_test.go | 57 +++- pkg/repositories/interfaces/resource_repo.go | 3 + pkg/repositories/mocks/resource.go | 8 + pkg/repositories/models/resource.go | 2 +- pkg/repositories/transformers/resource.go | 21 +- .../transformers/resource_test.go | 37 ++- pkg/rpc/adminservice/attributes.go | 57 ++++ .../adminservice/tests/project_domain_test.go | 76 +++++ .../interfaces/application_configuration.go | 25 ++ 24 files changed, 1236 insertions(+), 116 deletions(-) create mode 100644 pkg/manager/impl/shared/iface.go diff --git a/flyteadmin_config.yaml b/flyteadmin_config.yaml index 4b666ad46..26897bd05 100644 --- a/flyteadmin_config.yaml +++ b/flyteadmin_config.yaml @@ -6,7 +6,7 @@ server: httpPort: 8088 grpcPort: 8089 grpcServerReflection: true - kube-config: /Users/haythamabuelfutuh/kubeconfig/k3s/k3s.yaml + kube-config: /Users/ytong/.kube/config security: secure: false useAuth: false @@ -66,7 +66,7 @@ database: port: 5432 username: postgres host: localhost - dbname: postgres + dbname: flyteadmin options: "sslmode=disable" scheduler: eventScheduler: @@ -122,7 +122,7 @@ storage: auth-type: accesskey secret-key: miniostorage disable-ssl: true - endpoint: "http://localhost:9000" + endpoint: "http://localhost:30084" region: my-region signedUrl: stowConfigOverride: diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index 69685353c..f5ab8a0d2 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -51,7 +51,6 @@ import ( "github.com/benbjohnson/clock" "github.com/flyteorg/flyteadmin/pkg/manager/impl/shared" "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes/wrappers" ) const childContainerQueueKey = "child_queue" @@ -434,59 +433,6 @@ func (m *ExecutionManager) getInheritedExecMetadata(ctx context.Context, request return parentNodeExecutionID, sourceExecutionID, nil } -// WorkflowExecutionConfigInterface is used as common interface for capturing the common behavior catering to the needs -// of fetching the WorkflowExecutionConfig across LaunchPlanSpec, ExecutionCreateRequest -// MatchableResource_WORKFLOW_EXECUTION_CONFIG and ApplicationConfig -type WorkflowExecutionConfigInterface interface { - // GetMaxParallelism Can be used to control the number of parallel nodes to run within the workflow. This is useful to achieve fairness. - GetMaxParallelism() int32 - // GetRawOutputDataConfig Encapsulates user settings pertaining to offloaded data (i.e. Blobs, Schema, query data, etc.). - GetRawOutputDataConfig() *admin.RawOutputDataConfig - // GetSecurityContext Indicates security context permissions for executions triggered with this matchable attribute. - GetSecurityContext() *core.SecurityContext - // GetAnnotations Custom annotations to be applied to a triggered execution resource. - GetAnnotations() *admin.Annotations - // GetLabels Custom labels to be applied to a triggered execution resource. - GetLabels() *admin.Labels - // GetInterruptible indicates a workflow should be flagged as interruptible for a single execution. If omitted, the workflow's default is used. - GetInterruptible() *wrappers.BoolValue -} - -// Merge into workflowExecConfig from spec and return true if any value has been changed -func mergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec WorkflowExecutionConfigInterface) admin.WorkflowExecutionConfig { - if workflowExecConfig.GetMaxParallelism() == 0 && spec.GetMaxParallelism() > 0 { - workflowExecConfig.MaxParallelism = spec.GetMaxParallelism() - } - - if workflowExecConfig.GetSecurityContext() == nil && spec.GetSecurityContext() != nil { - if spec.GetSecurityContext().GetRunAs() != nil && - (len(spec.GetSecurityContext().GetRunAs().GetK8SServiceAccount()) > 0 || - len(spec.GetSecurityContext().GetRunAs().GetIamRole()) > 0) { - workflowExecConfig.SecurityContext = spec.GetSecurityContext() - } - } - // Launchplan spec has label, annotation and rawOutputDataConfig initialized with empty values. - // Hence we do a deep check in the following conditions before assignment - if (workflowExecConfig.GetRawOutputDataConfig() == nil || - len(workflowExecConfig.GetRawOutputDataConfig().GetOutputLocationPrefix()) == 0) && - (spec.GetRawOutputDataConfig() != nil && len(spec.GetRawOutputDataConfig().OutputLocationPrefix) > 0) { - workflowExecConfig.RawOutputDataConfig = spec.GetRawOutputDataConfig() - } - if (workflowExecConfig.GetLabels() == nil || len(workflowExecConfig.GetLabels().Values) == 0) && - (spec.GetLabels() != nil && len(spec.GetLabels().Values) > 0) { - workflowExecConfig.Labels = spec.GetLabels() - } - if (workflowExecConfig.GetAnnotations() == nil || len(workflowExecConfig.GetAnnotations().Values) == 0) && - (spec.GetAnnotations() != nil && len(spec.GetAnnotations().Values) > 0) { - workflowExecConfig.Annotations = spec.GetAnnotations() - } - - if workflowExecConfig.GetInterruptible() == nil && spec.GetInterruptible() != nil { - workflowExecConfig.Interruptible = spec.GetInterruptible() - } - return workflowExecConfig -} - // Produces execution-time attributes for workflow execution. // Defaults to overridable execution values set in the execution create request, then looks at the launch plan values // (if any) before defaulting to values set in the matchable resource db and further if matchable resources don't @@ -495,30 +441,49 @@ func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admi launchPlan *admin.LaunchPlan) (*admin.WorkflowExecutionConfig, error) { workflowExecConfig := admin.WorkflowExecutionConfig{} - // merge the request spec into workflowExecConfig - workflowExecConfig = mergeIntoExecConfig(workflowExecConfig, request.Spec) + // Merge the request spec into workflowExecConfig + workflowExecConfig = util.MergeIntoExecConfig(workflowExecConfig, request.Spec) var workflowName string if launchPlan != nil && launchPlan.Spec != nil { - // merge the launch plan spec into workflowExecConfig - workflowExecConfig = mergeIntoExecConfig(workflowExecConfig, launchPlan.Spec) + // Merge the launch plan spec into workflowExecConfig + workflowExecConfig = util.MergeIntoExecConfig(workflowExecConfig, launchPlan.Spec) if launchPlan.Spec.WorkflowId != nil { workflowName = launchPlan.Spec.WorkflowId.Name } } + // This will get the most specific Workflow Execution Config. matchableResource, err := util.GetMatchableResource(ctx, m.resourceManager, admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, request.Project, request.Domain, workflowName) if err != nil { return nil, err } - if matchableResource != nil && matchableResource.Attributes.GetWorkflowExecutionConfig() != nil { // merge the matchable resource workflow execution config into workflowExecConfig - workflowExecConfig = mergeIntoExecConfig(workflowExecConfig, + workflowExecConfig = util.MergeIntoExecConfig(workflowExecConfig, matchableResource.Attributes.GetWorkflowExecutionConfig()) } + // To match what the front-end will display to the user, we need to do the project level query too. + // This searches only for a direct match, and will not merge in system config level defaults like the + // GetProjectAttributes call does, since that's done below. + // The reason we need to do the project level query is for the case where some configs (say max parallelism) + // is set on the project level, but other items (say service account) is set on the project-domain level. + // In this case you want to use the project-domain service account, the project-level max parallelism, and + // system level defaults for the rest. + // See FLYTE-2322 for more background information. + projectMatchableResource, err := util.GetMatchableResource(ctx, m.resourceManager, + admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, request.Project, "", "") + if err != nil { + return nil, err + } + if projectMatchableResource != nil && projectMatchableResource.Attributes.GetWorkflowExecutionConfig() != nil { + // merge the matchable resource workflow execution config into workflowExecConfig + workflowExecConfig = util.MergeIntoExecConfig(workflowExecConfig, + projectMatchableResource.Attributes.GetWorkflowExecutionConfig()) + } + // Backward compatibility changes to get security context from auth role. // Older authRole or auth fields in the launchplan spec or execution request need to be used over application defaults. // This portion of the code makes sure if newer way of setting security context is empty i.e @@ -530,8 +495,8 @@ func (m *ExecutionManager) getExecutionConfig(ctx context.Context, request *admi len(resolvedSecurityCtx.GetRunAs().GetIamRole()) > 0) { workflowExecConfig.SecurityContext = resolvedSecurityCtx } - // merge the application config into workflowExecConfig. If even the deprecated fields are not set - workflowExecConfig = mergeIntoExecConfig(workflowExecConfig, m.config.ApplicationConfiguration().GetTopLevelConfig()) + // Merge the application config into workflowExecConfig. If even the deprecated fields are not set + workflowExecConfig = util.MergeIntoExecConfig(workflowExecConfig, m.config.ApplicationConfiguration().GetTopLevelConfig()) // Explicitly set the security context if its nil since downstream we expect this settings to be available if workflowExecConfig.GetSecurityContext() == nil { workflowExecConfig.SecurityContext = &core.SecurityContext{ diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 61c681345..309f3cde0 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -4026,18 +4026,21 @@ func TestGetExecutionConfigOverrides(t *testing.T) { } resourceManager.GetResourceFunc = func(ctx context.Context, request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + // two requests will be made, one with empty domain and one with filled in domain + assert.Contains(t, []managerInterfaces.ResourceRequest{{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, - }) - return &managerInterfaces.ResourceResponse{ + }, {Project: workflowIdentifier.Project, + Domain: "", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG}, + }, request) + projectDomainResponse := &managerInterfaces.ResourceResponse{ Attributes: &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ MaxParallelism: rmMaxParallelism, Interruptible: &wrappers.BoolValue{Value: rmInterruptible}, - Labels: &admin.Labels{Values: rmLabels}, Annotations: &admin.Annotations{Values: rmAnnotations}, RawOutputDataConfig: &admin.RawOutputDataConfig{ OutputLocationPrefix: rmOutputLocationPrefix, @@ -4050,7 +4053,24 @@ func TestGetExecutionConfigOverrides(t *testing.T) { }, }, }, - }, nil + } + + projectResponse := &managerInterfaces.ResourceResponse{ + Attributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ + WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ + Labels: &admin.Labels{Values: rmLabels}, + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "shouldnotbeused", + }, + }, + }, + }, + } + if request.Domain == "" { + return projectResponse, nil + } + return projectDomainResponse, nil } t.Run("request with full config", func(t *testing.T) { @@ -4234,11 +4254,15 @@ func TestGetExecutionConfigOverrides(t *testing.T) { t.Run("matchable resource partial config", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + assert.Contains(t, []managerInterfaces.ResourceRequest{{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, - }) + }, {Project: workflowIdentifier.Project, + Domain: "", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG}, + }, request) + return &managerInterfaces.ResourceResponse{ Attributes: &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ @@ -4275,11 +4299,14 @@ func TestGetExecutionConfigOverrides(t *testing.T) { t.Run("matchable resource with no config", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + assert.Contains(t, []managerInterfaces.ResourceRequest{{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, - }) + }, {Project: workflowIdentifier.Project, + Domain: "", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG}, + }, request) return &managerInterfaces.ResourceResponse{ Attributes: &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ @@ -4308,11 +4335,15 @@ func TestGetExecutionConfigOverrides(t *testing.T) { t.Run("fetch security context from deprecated config", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + assert.Contains(t, []managerInterfaces.ResourceRequest{{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, - }) + }, {Project: workflowIdentifier.Project, + Domain: "", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG}, + }, request) + return &managerInterfaces.ResourceResponse{ Attributes: &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ @@ -4345,12 +4376,17 @@ func TestGetExecutionConfigOverrides(t *testing.T) { t.Run("matchable resource workflow resource", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + assert.Contains(t, []managerInterfaces.ResourceRequest{{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, - Workflow: workflowIdentifier.Name, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, - }) + Workflow: workflowIdentifier.Name, + }, {Project: workflowIdentifier.Project, + Domain: "", + Workflow: "", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG}, + }, request) + return &managerInterfaces.ResourceResponse{ Attributes: &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ @@ -4391,11 +4427,14 @@ func TestGetExecutionConfigOverrides(t *testing.T) { t.Run("matchable resource failure", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + assert.Contains(t, []managerInterfaces.ResourceRequest{{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, - }) + }, {Project: workflowIdentifier.Project, + Domain: "", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG}, + }, request) return nil, fmt.Errorf("failed to fetch the resources") } request := &admin.ExecutionCreateRequest{ @@ -4417,11 +4456,14 @@ func TestGetExecutionConfigOverrides(t *testing.T) { t.Run("application configuration", func(t *testing.T) { resourceManager.GetResourceFunc = func(ctx context.Context, request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + assert.Contains(t, []managerInterfaces.ResourceRequest{{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, - }) + }, {Project: workflowIdentifier.Project, + Domain: "", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG}, + }, request) return &managerInterfaces.ResourceResponse{ Attributes: &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ @@ -4580,11 +4622,14 @@ func TestGetExecutionConfig(t *testing.T) { resourceManager := managerMocks.MockResourceManager{} resourceManager.GetResourceFunc = func(ctx context.Context, request managerInterfaces.ResourceRequest) (*managerInterfaces.ResourceResponse, error) { - assert.EqualValues(t, request, managerInterfaces.ResourceRequest{ + assert.Contains(t, []managerInterfaces.ResourceRequest{{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, - }) + }, {Project: workflowIdentifier.Project, + Domain: "", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG}, + }, request) return &managerInterfaces.ResourceResponse{ Attributes: &admin.MatchingAttributes{ Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ diff --git a/pkg/manager/impl/resources/resource_manager.go b/pkg/manager/impl/resources/resource_manager.go index 870575c70..a02cd024a 100644 --- a/pkg/manager/impl/resources/resource_manager.go +++ b/pkg/manager/impl/resources/resource_manager.go @@ -3,6 +3,8 @@ package resources import ( "context" + "github.com/flyteorg/flyteadmin/pkg/manager/impl/util" + "github.com/flyteorg/flyteadmin/pkg/repositories/models" "github.com/flyteorg/flyteadmin/pkg/errors" @@ -148,6 +150,123 @@ func (m *ResourceManager) DeleteWorkflowAttributes(ctx context.Context, return &admin.WorkflowAttributesDeleteResponse{}, nil } +func (m *ResourceManager) UpdateProjectAttributes(ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( + *admin.ProjectAttributesUpdateResponse, error) { + + var resource admin.MatchableResource + var err error + + if resource, err = validation.ValidateProjectAttributesUpdateRequest(ctx, m.db, request); err != nil { + return nil, err + } + model, err := transformers.ProjectAttributesToResourceModel(*request.Attributes, resource) + if err != nil { + return nil, err + } + + if request.Attributes.GetMatchingAttributes().GetPluginOverrides() != nil { + return m.createOrMergeUpdateProjectAttributes(ctx, request, model, admin.MatchableResource_PLUGIN_OVERRIDE) + } + + err = m.db.ResourceRepo().CreateOrUpdate(ctx, model) + if err != nil { + return nil, err + } + + return &admin.ProjectAttributesUpdateResponse{}, nil +} + +func (m *ResourceManager) GetProjectAttributesBase(ctx context.Context, request admin.ProjectAttributesGetRequest) ( + *admin.ProjectAttributesGetResponse, error) { + + if err := validation.ValidateProjectExists(ctx, m.db, request.Project); err != nil { + return nil, err + } + + projectAttributesModel, err := m.db.ResourceRepo().GetProjectLevel( + ctx, repo_interface.ResourceID{Project: request.Project, Domain: "", ResourceType: request.ResourceType.String()}) + if err != nil { + return nil, err + } + + ma, err := transformers.FromResourceModelToMatchableAttributes(projectAttributesModel) + if err != nil { + return nil, err + } + + return &admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: request.Project, + MatchingAttributes: ma.Attributes, + }, + }, nil +} + +// GetProjectAttributes combines the call to the database to get the Project level settings with +// Admin server level configuration. +// Note this merge is only done for WorkflowExecutionConfig +// This code should be removed pending implementation of a complete settings implementation. +func (m *ResourceManager) GetProjectAttributes(ctx context.Context, request admin.ProjectAttributesGetRequest) ( + *admin.ProjectAttributesGetResponse, error) { + + getResponse, err := m.GetProjectAttributesBase(ctx, request) + configLevelDefaults := m.config.GetTopLevelConfig().GetAsWorkflowExecutionConfig() + if err != nil { + ec, ok := err.(errors.FlyteAdminError) + if ok && ec.Code() == codes.NotFound { + // TODO: Will likely be removed after overarching settings project is done + // Proceed with the default CreateOrUpdate call since there's no existing model to update. + return &admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: request.Project, + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ + WorkflowExecutionConfig: &configLevelDefaults, + }, + }, + }, + }, nil + } + return nil, err + + } + // If found, then merge result with the default values for the platform + // TODO: Remove this logic once the overarching settings project is done. Those endpoints should take + // default configuration into account. + responseAttributes := getResponse.Attributes.GetMatchingAttributes().GetWorkflowExecutionConfig() + if responseAttributes != nil { + logger.Warningf(ctx, "Merging response %s with defaults %s", responseAttributes, configLevelDefaults) + tmp := util.MergeIntoExecConfig(*responseAttributes, &configLevelDefaults) + responseAttributes = &tmp + return &admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: request.Project, + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ + WorkflowExecutionConfig: responseAttributes, + }, + }, + }, + }, nil + } + + return getResponse, nil +} + +func (m *ResourceManager) DeleteProjectAttributes(ctx context.Context, request admin.ProjectAttributesDeleteRequest) ( + *admin.ProjectAttributesDeleteResponse, error) { + + if err := validation.ValidateProjectForUpdate(ctx, m.db, request.Project); err != nil { + return nil, err + } + if err := m.db.ResourceRepo().Delete( + ctx, repo_interface.ResourceID{Project: request.Project, ResourceType: request.ResourceType.String()}); err != nil { + return nil, err + } + logger.Infof(ctx, "Deleted project attributes for: %s-%s (%s)", request.Project, request.ResourceType.String()) + return &admin.ProjectAttributesDeleteResponse{}, nil +} + func (m *ResourceManager) createOrMergeUpdateProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest, model models.Resource, resourceType admin.MatchableResource) (*admin.ProjectDomainAttributesUpdateResponse, error) { @@ -171,8 +290,8 @@ func (m *ResourceManager) createOrMergeUpdateProjectDomainAttributes( } return nil, err } - updatedModel, err := transformers.MergeUpdateProjectDomainAttributes( - ctx, existing, resourceType, &resourceID, request.Attributes) + updatedModel, err := transformers.MergeUpdatePluginAttributes( + ctx, existing, resourceType, &resourceID, request.Attributes.MatchingAttributes) if err != nil { return nil, err } @@ -183,6 +302,42 @@ func (m *ResourceManager) createOrMergeUpdateProjectDomainAttributes( return &admin.ProjectDomainAttributesUpdateResponse{}, nil } +func (m *ResourceManager) createOrMergeUpdateProjectAttributes( + ctx context.Context, request admin.ProjectAttributesUpdateRequest, model models.Resource, + resourceType admin.MatchableResource) (*admin.ProjectAttributesUpdateResponse, error) { + + resourceID := repo_interface.ResourceID{ + Project: model.Project, + Domain: model.Domain, + Workflow: model.Workflow, + LaunchPlan: model.LaunchPlan, + ResourceType: model.ResourceType, + } + existing, err := m.db.ResourceRepo().GetRaw(ctx, resourceID) + if err != nil { + ec, ok := err.(errors.FlyteAdminError) + if ok && ec.Code() == codes.NotFound { + // Proceed with the default CreateOrUpdate call since there's no existing model to update. + err = m.db.ResourceRepo().CreateOrUpdate(ctx, model) + if err != nil { + return nil, err + } + return &admin.ProjectAttributesUpdateResponse{}, nil + } + return nil, err + } + updatedModel, err := transformers.MergeUpdatePluginAttributes( + ctx, existing, resourceType, &resourceID, request.Attributes.MatchingAttributes) + if err != nil { + return nil, err + } + err = m.db.ResourceRepo().CreateOrUpdate(ctx, updatedModel) + if err != nil { + return nil, err + } + return &admin.ProjectAttributesUpdateResponse{}, nil +} + func (m *ResourceManager) UpdateProjectDomainAttributes( ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) { diff --git a/pkg/manager/impl/resources/resource_manager_test.go b/pkg/manager/impl/resources/resource_manager_test.go index 7a18ab092..ad886e52e 100644 --- a/pkg/manager/impl/resources/resource_manager_test.go +++ b/pkg/manager/impl/resources/resource_manager_test.go @@ -2,8 +2,15 @@ package resources import ( "context" + + runtimeInterfaces "github.com/flyteorg/flyteadmin/pkg/runtime/interfaces" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + + // pkg/runtime/interfaces/application_configuration.go "testing" + runtimeMocks "github.com/flyteorg/flyteadmin/pkg/runtime/mocks" + "github.com/flyteorg/flyteadmin/pkg/errors" "google.golang.org/grpc/codes" @@ -22,6 +29,8 @@ import ( const project = "project" const domain = "domain" const workflow = "workflow" +const python = "python" +const hive = "hive" func TestUpdateWorkflowAttributes(t *testing.T) { request := admin.WorkflowAttributesUpdateRequest{ @@ -124,9 +133,9 @@ func TestUpdateWorkflowAttributes_CreateOrMerge(t *testing.T) { assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 2) for _, override := range attributesToBeSaved.GetPluginOverrides().Overrides { - if override.TaskType == "python" { + if override.TaskType == python { assert.EqualValues(t, []string{"plugin a"}, override.PluginId) - } else if override.TaskType == "hive" { + } else if override.TaskType == hive { assert.EqualValues(t, []string{"plugin b"}, override.PluginId) } else { t.Errorf("Unexpected task type [%s] plugin override committed to db", override.TaskType) @@ -253,7 +262,7 @@ func TestUpdateProjectDomainAttributes_CreateOrMerge(t *testing.T) { } assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 1) assert.True(t, proto.Equal(attributesToBeSaved.GetPluginOverrides().Overrides[0], &admin.PluginOverride{ - TaskType: "python", + TaskType: python, PluginId: []string{"plugin a"}})) createOrUpdateCalled = true @@ -295,9 +304,9 @@ func TestUpdateProjectDomainAttributes_CreateOrMerge(t *testing.T) { assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 2) for _, override := range attributesToBeSaved.GetPluginOverrides().Overrides { - if override.TaskType == "python" { + if override.TaskType == python { assert.EqualValues(t, []string{"plugin a"}, override.PluginId) - } else if override.TaskType == "hive" { + } else if override.TaskType == hive { assert.EqualValues(t, []string{"plugin b"}, override.PluginId) } else { t.Errorf("Unexpected task type [%s] plugin override committed to db", override.TaskType) @@ -365,6 +374,297 @@ func TestDeleteProjectDomainAttributes(t *testing.T) { assert.Nil(t, err) } +func TestUpdateProjectAttributes(t *testing.T) { + request := admin.ProjectAttributesUpdateRequest{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: testutils.WorkflowExecutionConfigSample, + }, + } + db := mocks.NewMockRepository() + expectedSerializedAttrs, _ := proto.Marshal(testutils.WorkflowExecutionConfigSample) + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func( + ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, "", input.Domain) + assert.Equal(t, "", input.Workflow) + assert.Equal(t, admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG.String(), input.ResourceType) + assert.EqualValues(t, expectedSerializedAttrs, input.Attributes) + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + _, err := manager.UpdateProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, createOrUpdateCalled) + + // Test empty attributes + request = admin.ProjectAttributesUpdateRequest{Attributes: nil} + _, err = manager.UpdateProjectAttributes(context.Background(), request) + assert.Error(t, err) + + // Test error handling + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func( + ctx context.Context, input models.Resource) error { + return errors.NewFlyteAdminErrorf(123, "123") + } + request = admin.ProjectAttributesUpdateRequest{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: testutils.WorkflowExecutionConfigSample, + }, + } + _, err = manager.UpdateProjectAttributes(context.Background(), request) + assert.Error(t, err, "123") +} + +func TestUpdateProjectAttributes_CreateOrMerge(t *testing.T) { + request := admin.ProjectAttributesUpdateRequest{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: commonTestUtils.GetPluginOverridesAttributes(map[string][]string{"python": {"plugin a"}}), + }, + } + + t.Run("create only", func(t *testing.T) { + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID repoInterfaces.ResourceID) ( + models.Resource, error) { + return models.Resource{}, errors.NewFlyteAdminError(codes.NotFound, "foo") + } + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func(ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, "", input.Domain) + + var attributesToBeSaved admin.MatchingAttributes + err := proto.Unmarshal(input.Attributes, &attributesToBeSaved) + if err != nil { + t.Fatal(err) + } + assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 1) + assert.True(t, proto.Equal(attributesToBeSaved.GetPluginOverrides().Overrides[0], &admin.PluginOverride{ + TaskType: python, + PluginId: []string{"plugin a"}})) + + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + _, err := manager.UpdateProjectAttributes(context.Background(), request) + assert.NoError(t, err) + assert.True(t, createOrUpdateCalled) + }) + t.Run("merge update", func(t *testing.T) { + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func(ctx context.Context, ID repoInterfaces.ResourceID) ( + models.Resource, error) { + existingAttributes := commonTestUtils.GetPluginOverridesAttributes(map[string][]string{ + "hive": {"plugin b"}, + "python": {"plugin c"}, + }) + bytes, err := proto.Marshal(existingAttributes) + if err != nil { + t.Fatal(err) + } + return models.Resource{ + Project: project, + Attributes: bytes, + }, nil + } + var createOrUpdateCalled bool + db.ResourceRepo().(*mocks.MockResourceRepo).CreateOrUpdateFunction = func(ctx context.Context, input models.Resource) error { + assert.Equal(t, project, input.Project) + assert.Equal(t, "", input.Domain) + + var attributesToBeSaved admin.MatchingAttributes + err := proto.Unmarshal(input.Attributes, &attributesToBeSaved) + if err != nil { + t.Fatal(err) + } + + assert.Len(t, attributesToBeSaved.GetPluginOverrides().Overrides, 2) + for _, override := range attributesToBeSaved.GetPluginOverrides().Overrides { + if override.TaskType == python { + assert.EqualValues(t, []string{"plugin a"}, override.PluginId) + } else if override.TaskType == hive { + assert.EqualValues(t, []string{"plugin b"}, override.PluginId) + } else { + t.Errorf("Unexpected task type [%s] plugin override committed to db", override.TaskType) + } + } + createOrUpdateCalled = true + return nil + } + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + _, err := manager.UpdateProjectAttributes(context.Background(), request) + assert.NoError(t, err) + assert.True(t, createOrUpdateCalled) + }) +} + +func TestGetProjectAttributes(t *testing.T) { + request := admin.ProjectAttributesGetRequest{ + Project: project, + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, + } + db := mocks.NewMockRepository() + + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) (models.Resource, error) { + + assert.Equal(t, project, ID.Project) + assert.Equal(t, "", ID.Domain) + assert.Equal(t, "", ID.Workflow) + assert.Equal(t, admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG.String(), ID.ResourceType) + expectedSerializedAttrs, _ := proto.Marshal(testutils.WorkflowExecutionConfigSample) + return models.Resource{ + Project: project, + Domain: "", + ResourceType: "resource", + Attributes: expectedSerializedAttrs, + }, nil + } + response, err := manager.GetProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: testutils.WorkflowExecutionConfigSample, + }, + }, response)) + + // unrecognized errors are thrown + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) (models.Resource, error) { + + return models.Resource{}, errors.NewFlyteAdminErrorf(5323, "random code") + } + _, err = manager.GetProjectAttributes(context.Background(), request) + assert.Error(t, err) +} + +func TestGetProjectAttributes_ConfigLookup(t *testing.T) { + request := admin.ProjectAttributesGetRequest{ + Project: project, + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, + } + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).GetFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) (models.Resource, error) { + // return not found to trigger loading from config + return models.Resource{}, errors.NewFlyteAdminError(codes.NotFound, "not found message") + } + config := runtimeMocks.MockApplicationProvider{} + manager := NewResourceManager(db, &config) + + t.Run("config 1", func(t *testing.T) { + appConfig := runtimeInterfaces.ApplicationConfig{ + MaxParallelism: 3, + K8SServiceAccount: "testserviceaccount", + Labels: map[string]string{"lab1": "name"}, + OutputLocationPrefix: "s3://test-bucket", + } + config.SetTopLevelConfig(appConfig) + + response, err := manager.GetProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ + WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ + MaxParallelism: 3, + SecurityContext: &core.SecurityContext{ + RunAs: &core.Identity{K8SServiceAccount: "testserviceaccount"}, + }, + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://test-bucket", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "name"}, + }, + }, + }, + }, + }, + }, response)) + }) + + t.Run("config 2", func(t *testing.T) { + appConfig := runtimeInterfaces.ApplicationConfig{ + MaxParallelism: 3, + AssumableIamRole: "myrole", + } + config.SetTopLevelConfig(appConfig) + + response, err := manager.GetProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ + WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ + MaxParallelism: 3, + SecurityContext: &core.SecurityContext{ + RunAs: &core.Identity{IamRole: "myrole"}, + }, + }, + }, + }, + }, + }, response)) + }) + + t.Run("config 3", func(t *testing.T) { + appConfig := runtimeInterfaces.ApplicationConfig{ + MaxParallelism: 3, + Annotations: map[string]string{"ann1": "val1"}, + } + config.SetTopLevelConfig(appConfig) + + response, err := manager.GetProjectAttributes(context.Background(), request) + assert.Nil(t, err) + assert.True(t, proto.Equal(&admin.ProjectAttributesGetResponse{ + Attributes: &admin.ProjectAttributes{ + Project: project, + MatchingAttributes: &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ + WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ + MaxParallelism: 3, + Annotations: &admin.Annotations{ + Values: map[string]string{"ann1": "val1"}, + }, + }, + }, + }, + }, + }, response)) + }) +} + +func TestDeleteProjectAttributes(t *testing.T) { + request := admin.ProjectAttributesDeleteRequest{ + Project: project, + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, + } + db := mocks.NewMockRepository() + db.ResourceRepo().(*mocks.MockResourceRepo).DeleteFunction = func( + ctx context.Context, ID repoInterfaces.ResourceID) error { + assert.Equal(t, project, ID.Project) + assert.Equal(t, "", ID.Domain) + assert.Equal(t, admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG.String(), ID.ResourceType) + return nil + } + manager := NewResourceManager(db, testutils.GetApplicationConfigWithDefaultDomains()) + _, err := manager.DeleteProjectAttributes(context.Background(), request) + assert.Nil(t, err) +} + func TestGetResource(t *testing.T) { request := interfaces.ResourceRequest{ Project: project, diff --git a/pkg/manager/impl/shared/iface.go b/pkg/manager/impl/shared/iface.go new file mode 100644 index 000000000..8cf83882b --- /dev/null +++ b/pkg/manager/impl/shared/iface.go @@ -0,0 +1,25 @@ +package shared + +import ( + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" + "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/core" + "github.com/golang/protobuf/ptypes/wrappers" +) + +// WorkflowExecutionConfigInterface is used as common interface for capturing the common behavior catering to the needs +// of fetching the WorkflowExecutionConfig across LaunchPlanSpec, ExecutionCreateRequest +// MatchableResource_WORKFLOW_EXECUTION_CONFIG and ApplicationConfig +type WorkflowExecutionConfigInterface interface { + // GetMaxParallelism Can be used to control the number of parallel nodes to run within the workflow. This is useful to achieve fairness. + GetMaxParallelism() int32 + // GetRawOutputDataConfig Encapsulates user settings pertaining to offloaded data (i.e. Blobs, Schema, query data, etc.). + GetRawOutputDataConfig() *admin.RawOutputDataConfig + // GetSecurityContext Indicates security context permissions for executions triggered with this matchable attribute. + GetSecurityContext() *core.SecurityContext + // GetAnnotations Custom annotations to be applied to a triggered execution resource. + GetAnnotations() *admin.Annotations + // GetLabels Custom labels to be applied to a triggered execution resource. + GetLabels() *admin.Labels + // GetInterruptible indicates a workflow should be flagged as interruptible for a single execution. If omitted, the workflow's default is used. + GetInterruptible() *wrappers.BoolValue +} diff --git a/pkg/manager/impl/testutils/attributes.go b/pkg/manager/impl/testutils/attributes.go index 83b7a8d67..92d276d3e 100644 --- a/pkg/manager/impl/testutils/attributes.go +++ b/pkg/manager/impl/testutils/attributes.go @@ -11,3 +11,18 @@ var ExecutionQueueAttributes = &admin.MatchingAttributes{ }, }, } + +var WorkflowExecutionConfigSample = &admin.MatchingAttributes{ + Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ + WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ + MaxParallelism: 5, + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://test-bucket", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "val1"}, + }, + Annotations: nil, + }, + }, +} diff --git a/pkg/manager/impl/util/shared.go b/pkg/manager/impl/util/shared.go index 3758775fa..8bf449af8 100644 --- a/pkg/manager/impl/util/shared.go +++ b/pkg/manager/impl/util/shared.go @@ -254,3 +254,42 @@ func GetMatchableResource(ctx context.Context, resourceManager interfaces.Resour } return matchableResource, nil } + +// MergeIntoExecConfig into workflowExecConfig (higher priority) from spec (lower priority) and return the +// a new object with the merged changes. +// After settings project is done, can move this function back to execution manager. Currently shared with resource. +func MergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec shared.WorkflowExecutionConfigInterface) admin.WorkflowExecutionConfig { + if workflowExecConfig.GetMaxParallelism() == 0 && spec.GetMaxParallelism() > 0 { + workflowExecConfig.MaxParallelism = spec.GetMaxParallelism() + } + + // Do a deep check on the spec in case the security context is set but to an empty object (which may be + // the case when coming from the UI) + if workflowExecConfig.GetSecurityContext() == nil && spec.GetSecurityContext() != nil { + if spec.GetSecurityContext().GetRunAs() != nil && + (len(spec.GetSecurityContext().GetRunAs().GetK8SServiceAccount()) > 0 || + len(spec.GetSecurityContext().GetRunAs().GetIamRole()) > 0) { + workflowExecConfig.SecurityContext = spec.GetSecurityContext() + } + } + // Launchplan spec has label, annotation and rawOutputDataConfig initialized with empty values. + // Hence we do a deep check in the following conditions before assignment + if (workflowExecConfig.GetRawOutputDataConfig() == nil || + len(workflowExecConfig.GetRawOutputDataConfig().GetOutputLocationPrefix()) == 0) && + (spec.GetRawOutputDataConfig() != nil && len(spec.GetRawOutputDataConfig().OutputLocationPrefix) > 0) { + workflowExecConfig.RawOutputDataConfig = spec.GetRawOutputDataConfig() + } + if (workflowExecConfig.GetLabels() == nil || len(workflowExecConfig.GetLabels().Values) == 0) && + (spec.GetLabels() != nil && len(spec.GetLabels().Values) > 0) { + workflowExecConfig.Labels = spec.GetLabels() + } + if (workflowExecConfig.GetAnnotations() == nil || len(workflowExecConfig.GetAnnotations().Values) == 0) && + (spec.GetAnnotations() != nil && len(spec.GetAnnotations().Values) > 0) { + workflowExecConfig.Annotations = spec.GetAnnotations() + } + + if workflowExecConfig.GetInterruptible() == nil && spec.GetInterruptible() != nil { + workflowExecConfig.Interruptible = spec.GetInterruptible() + } + return workflowExecConfig +} diff --git a/pkg/manager/impl/util/shared_test.go b/pkg/manager/impl/util/shared_test.go index 6885c9ee3..56b658332 100644 --- a/pkg/manager/impl/util/shared_test.go +++ b/pkg/manager/impl/util/shared_test.go @@ -3,9 +3,12 @@ package util import ( "context" "errors" + "fmt" "strings" "testing" + "github.com/golang/protobuf/ptypes/wrappers" + "github.com/flyteorg/flyteadmin/pkg/common" commonMocks "github.com/flyteorg/flyteadmin/pkg/common/mocks" flyteAdminErrors "github.com/flyteorg/flyteadmin/pkg/errors" @@ -564,3 +567,137 @@ func TestGetMatchableResource(t *testing.T) { assert.NotNil(t, err) }) } + +func TestMergeIntoExecConfig(t *testing.T) { + var res admin.WorkflowExecutionConfig + parameters := []struct { + higher, lower, expected admin.WorkflowExecutionConfig + }{ + // Max Parallelism taken from higher + { + admin.WorkflowExecutionConfig{ + MaxParallelism: 5, + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://test-bucket", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "val1"}, + }, + Annotations: &admin.Annotations{ + Values: map[string]string{"ann1": "annval"}, + }, + }, + admin.WorkflowExecutionConfig{ + MaxParallelism: 0, + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://asdf", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "oldvalue"}, + }, + }, + admin.WorkflowExecutionConfig{ + MaxParallelism: 5, + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://test-bucket", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "val1"}, + }, + Annotations: &admin.Annotations{ + Values: map[string]string{"ann1": "annval"}, + }, + }, + }, + + // Values that are set to empty in higher priority get overwritten + { + admin.WorkflowExecutionConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "", + }, + Labels: &admin.Labels{ + Values: map[string]string{}, + }, + Annotations: &admin.Annotations{ + Values: map[string]string{}, + }, + }, + admin.WorkflowExecutionConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://asdf", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "oldvalue"}, + }, + Annotations: &admin.Annotations{ + Values: map[string]string{"ann1": "annval"}, + }, + }, + admin.WorkflowExecutionConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://asdf", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "oldvalue"}, + }, + Annotations: &admin.Annotations{ + Values: map[string]string{"ann1": "annval"}, + }, + }, + }, + + // Values that are not set at all get merged in + { + admin.WorkflowExecutionConfig{}, + admin.WorkflowExecutionConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://asdf", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "oldvalue"}, + }, + Annotations: &admin.Annotations{ + Values: map[string]string{"ann1": "annval"}, + }, + }, + admin.WorkflowExecutionConfig{ + RawOutputDataConfig: &admin.RawOutputDataConfig{ + OutputLocationPrefix: "s3://asdf", + }, + Labels: &admin.Labels{ + Values: map[string]string{"lab1": "oldvalue"}, + }, + Annotations: &admin.Annotations{ + Values: map[string]string{"ann1": "annval"}, + }, + }, + }, + + // Interruptible + { + admin.WorkflowExecutionConfig{ + Interruptible: &wrappers.BoolValue{ + Value: false, + }, + }, + admin.WorkflowExecutionConfig{ + Interruptible: &wrappers.BoolValue{ + Value: true, + }, + }, + admin.WorkflowExecutionConfig{ + Interruptible: &wrappers.BoolValue{ + Value: false, + }, + }, + }, + } + + for i := range parameters { + t.Run(fmt.Sprintf("Testing [%v]", i), func(t *testing.T) { + res = MergeIntoExecConfig(parameters[i].higher, ¶meters[i].lower) + assert.True(t, proto.Equal(¶meters[i].expected, &res)) + }) + } +} diff --git a/pkg/manager/impl/validation/attributes_validator.go b/pkg/manager/impl/validation/attributes_validator.go index d0ad15a2e..9a123acd8 100644 --- a/pkg/manager/impl/validation/attributes_validator.go +++ b/pkg/manager/impl/validation/attributes_validator.go @@ -53,6 +53,21 @@ func ValidateProjectDomainAttributesUpdateRequest(ctx context.Context, fmt.Sprintf("%s-%s", request.Attributes.Project, request.Attributes.Domain)) } +func ValidateProjectAttributesUpdateRequest(ctx context.Context, + db repositoryInterfaces.Repository, + request admin.ProjectAttributesUpdateRequest) ( + admin.MatchableResource, error) { + + if request.Attributes == nil { + return defaultMatchableResource, shared.GetMissingArgumentError(shared.Attributes) + } + if err := ValidateProjectForUpdate(ctx, db, request.Attributes.Project); err != nil { + return defaultMatchableResource, err + } + + return validateMatchingAttributes(request.Attributes.MatchingAttributes, request.Attributes.Project) +} + func ValidateProjectDomainAttributesGetRequest(ctx context.Context, db repositoryInterfaces.Repository, config runtimeInterfaces.ApplicationConfiguration, request admin.ProjectDomainAttributesGetRequest) error { if err := ValidateProjectAndDomain(ctx, db, config, request.Project, request.Domain); err != nil { diff --git a/pkg/manager/impl/validation/project_validator.go b/pkg/manager/impl/validation/project_validator.go index c5795ef5c..03b427640 100644 --- a/pkg/manager/impl/validation/project_validator.go +++ b/pkg/manager/impl/validation/project_validator.go @@ -81,3 +81,33 @@ func ValidateProjectAndDomain( } return nil } + +func ValidateProjectForUpdate( + ctx context.Context, db repositoryInterfaces.Repository, projectID string) error { + + project, err := db.ProjectRepo().Get(ctx, projectID) + if err != nil { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, + "failed to validate that project [%s] is registered, err: [%+v]", + projectID, err) + } + if *project.State != int32(admin.Project_ACTIVE) { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, + "project [%s] is not active", projectID) + } + return nil +} + +// ValidateProjectExists doesn't check that the project is active. This is used to get Project level attributes, which you should +// be able to do even for inactive projects. +func ValidateProjectExists( + ctx context.Context, db repositoryInterfaces.Repository, projectID string) error { + + _, err := db.ProjectRepo().Get(ctx, projectID) + if err != nil { + return errors.NewFlyteAdminErrorf(codes.InvalidArgument, + "failed to validate that project [%s] exists, err: [%+v]", + projectID, err) + } + return nil +} diff --git a/pkg/manager/impl/validation/project_validator_test.go b/pkg/manager/impl/validation/project_validator_test.go index d53ac3d28..521d25085 100644 --- a/pkg/manager/impl/validation/project_validator_test.go +++ b/pkg/manager/impl/validation/project_validator_test.go @@ -304,3 +304,63 @@ func TestValidateProjectAndDomainNotFound(t *testing.T) { "flyte-project", "domain") assert.EqualError(t, err, "failed to validate that project [flyte-project] and domain [domain] are registered, err: [project [flyte-project] not found]") } + +func TestValidateProjectDb(t *testing.T) { + mockRepo := repositoryMocks.NewMockRepository() + t.Run("base case", func(t *testing.T) { + mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + assert.Equal(t, projectID, "flyte-project-id") + activeState := int32(admin.Project_ACTIVE) + return models.Project{State: &activeState}, nil + } + err := ValidateProjectForUpdate(context.Background(), mockRepo, "flyte-project-id") + + assert.Nil(t, err) + }) + + t.Run("error getting", func(t *testing.T) { + mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + + return models.Project{}, errors.New("missing") + } + err := ValidateProjectForUpdate(context.Background(), mockRepo, "flyte-project-id") + assert.Error(t, err) + }) + + t.Run("error archived", func(t *testing.T) { + mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + state := int32(admin.Project_ARCHIVED) + return models.Project{State: &state}, nil + } + err := ValidateProjectForUpdate(context.Background(), mockRepo, "flyte-project-id") + assert.Error(t, err) + }) +} + +func TestValidateProjectExistsDb(t *testing.T) { + mockRepo := repositoryMocks.NewMockRepository() + t.Run("base case", func(t *testing.T) { + mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + assert.Equal(t, projectID, "flyte-project-id") + activeState := int32(admin.Project_ACTIVE) + return models.Project{State: &activeState}, nil + } + err := ValidateProjectExists(context.Background(), mockRepo, "flyte-project-id") + + assert.Nil(t, err) + }) + + t.Run("error getting", func(t *testing.T) { + mockRepo.ProjectRepo().(*repositoryMocks.MockProjectRepo).GetFunction = func( + ctx context.Context, projectID string) (models.Project, error) { + + return models.Project{}, errors.New("missing") + } + err := ValidateProjectExists(context.Background(), mockRepo, "flyte-project-id") + assert.Error(t, err) + }) +} diff --git a/pkg/manager/interfaces/resource.go b/pkg/manager/interfaces/resource.go index eb1724609..42d4a3c9b 100644 --- a/pkg/manager/interfaces/resource.go +++ b/pkg/manager/interfaces/resource.go @@ -6,12 +6,19 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" ) -// Interface for managing project, domain and workflow -specific attributes. +// ResourceInterface manages project, domain and workflow -specific attributes. type ResourceInterface interface { ListAll(ctx context.Context, request admin.ListMatchableAttributesRequest) ( *admin.ListMatchableAttributesResponse, error) GetResource(ctx context.Context, request ResourceRequest) (*ResourceResponse, error) + UpdateProjectAttributes(ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( + *admin.ProjectAttributesUpdateResponse, error) + GetProjectAttributes(ctx context.Context, request admin.ProjectAttributesGetRequest) ( + *admin.ProjectAttributesGetResponse, error) + DeleteProjectAttributes(ctx context.Context, request admin.ProjectAttributesDeleteRequest) ( + *admin.ProjectAttributesDeleteResponse, error) + UpdateProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) GetProjectDomainAttributes(ctx context.Context, request admin.ProjectDomainAttributesGetRequest) ( diff --git a/pkg/manager/mocks/resource.go b/pkg/manager/mocks/resource.go index dada0e92a..1339e10d7 100644 --- a/pkg/manager/mocks/resource.go +++ b/pkg/manager/mocks/resource.go @@ -8,6 +8,13 @@ import ( "github.com/flyteorg/flyteidl/gen/pb-go/flyteidl/admin" ) +type UpdateProjectAttrsFunc func(ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( + *admin.ProjectAttributesUpdateResponse, error) +type GetProjectAttrFunc func(ctx context.Context, request admin.ProjectAttributesGetRequest) ( + *admin.ProjectAttributesGetResponse, error) +type DeleteProjectAttrFunc func(ctx context.Context, request admin.ProjectAttributesDeleteRequest) ( + *admin.ProjectAttributesDeleteResponse, error) + type UpdateProjectDomainFunc func(ctx context.Context, request admin.ProjectDomainAttributesUpdateRequest) ( *admin.ProjectDomainAttributesUpdateResponse, error) type GetProjectDomainFunc func(ctx context.Context, request admin.ProjectDomainAttributesGetRequest) ( @@ -24,6 +31,9 @@ type MockResourceManager struct { DeleteFunc DeleteProjectDomainFunc ListFunc ListResourceFunc GetResourceFunc GetResourceFunc + updateProjectAttrsFunc UpdateProjectAttrsFunc + getProjectAttrFunc GetProjectAttrFunc + deleteProjectAttrFunc DeleteProjectAttrFunc } func (m *MockResourceManager) GetResource(ctx context.Context, request interfaces.ResourceRequest) (*interfaces.ResourceResponse, error) { @@ -79,6 +89,42 @@ func (m *MockResourceManager) DeleteProjectDomainAttributes( return nil, nil } +func (m *MockResourceManager) SetUpdateProjectAttributes(updateProjectAttrsFunc UpdateProjectAttrsFunc) { + m.updateProjectAttrsFunc = updateProjectAttrsFunc +} + +func (m *MockResourceManager) UpdateProjectAttributes(ctx context.Context, request admin.ProjectAttributesUpdateRequest) ( + *admin.ProjectAttributesUpdateResponse, error) { + if m.updateProjectAttrsFunc != nil { + return m.updateProjectAttrsFunc(ctx, request) + } + return nil, nil +} + +func (m *MockResourceManager) SetGetProjectAttributes(getProjectFunc GetProjectAttrFunc) { + m.getProjectAttrFunc = getProjectFunc +} + +func (m *MockResourceManager) GetProjectAttributes(ctx context.Context, request admin.ProjectAttributesGetRequest) ( + *admin.ProjectAttributesGetResponse, error) { + if m.getProjectAttrFunc != nil { + return m.getProjectAttrFunc(ctx, request) + } + return nil, nil +} + +func (m *MockResourceManager) SetDeleteProjectAttributes(deleteProjectFunc DeleteProjectAttrFunc) { + m.deleteProjectAttrFunc = deleteProjectFunc +} + +func (m *MockResourceManager) DeleteProjectAttributes(ctx context.Context, request admin.ProjectAttributesDeleteRequest) ( + *admin.ProjectAttributesDeleteResponse, error) { + if m.deleteProjectAttrFunc != nil { + return m.deleteProjectAttrFunc(ctx, request) + } + return nil, nil +} + func (m *MockResourceManager) ListAll(ctx context.Context, request admin.ListMatchableAttributesRequest) ( *admin.ListMatchableAttributesResponse, error) { if m.ListFunc != nil { diff --git a/pkg/repositories/gormimpl/resource_repo.go b/pkg/repositories/gormimpl/resource_repo.go index e4b3f5dfd..1f94abce5 100644 --- a/pkg/repositories/gormimpl/resource_repo.go +++ b/pkg/repositories/gormimpl/resource_repo.go @@ -38,7 +38,10 @@ The data in the Resource repo maps to the following rules: ** Example: Domain="staging" Project="Lyft" Workflow="" LaunchPlan= "l1" is invalid. */ func validateCreateOrUpdateResourceInput(project, domain, workflow, launchPlan, resourceType string) bool { - if domain == "" || resourceType == "" { + if resourceType == "" { + return false + } + if domain == "" && project == "" { return false } if project == "" && (workflow != "" || launchPlan != "") { @@ -82,6 +85,7 @@ func (r *ResourceRepo) CreateOrUpdate(ctx context.Context, input models.Resource return nil } +// Get returns the most-specific attribute setting for the given ResourceType. func (r *ResourceRepo) Get(ctx context.Context, ID interfaces.ResourceID) (models.Resource, error) { if !validateCreateOrUpdateResourceInput(ID.Project, ID.Domain, ID.Workflow, ID.LaunchPlan, ID.ResourceType) { return models.Resource{}, r.errorTransformer.ToFlyteAdminError(flyteAdminDbErrors.GetInvalidInputError(fmt.Sprintf("%v", ID))) @@ -89,12 +93,17 @@ func (r *ResourceRepo) Get(ctx context.Context, ID interfaces.ResourceID) (model var resources []models.Resource timer := r.metrics.GetDuration.Start() - txWhereClause := "resource_type = ? AND domain = ? AND project IN (?) AND workflow IN (?) AND launch_plan IN (?)" + txWhereClause := "resource_type = ? AND domain IN (?) AND project IN (?) AND workflow IN (?) AND launch_plan IN (?)" project := []string{""} if ID.Project != "" { project = append(project, ID.Project) } + domain := []string{""} + if ID.Domain != "" { + domain = append(domain, ID.Domain) + } + workflow := []string{""} if ID.Workflow != "" { workflow = append(workflow, ID.Workflow) @@ -105,7 +114,33 @@ func (r *ResourceRepo) Get(ctx context.Context, ID interfaces.ResourceID) (model launchPlan = append(launchPlan, ID.LaunchPlan) } - tx := r.db.Where(txWhereClause, ID.ResourceType, ID.Domain, project, workflow, launchPlan) + tx := r.db.Where(txWhereClause, ID.ResourceType, domain, project, workflow, launchPlan) + tx.Order(priorityDescending).First(&resources) + timer.Stop() + + if (tx.Error != nil && errors.Is(tx.Error, gorm.ErrRecordNotFound)) || len(resources) == 0 { + return models.Resource{}, flyteAdminErrors.NewFlyteAdminErrorf(codes.NotFound, + "Resource [%+v] not found", ID) + } else if tx.Error != nil { + return models.Resource{}, r.errorTransformer.ToFlyteAdminError(tx.Error) + } + return resources[0], nil +} + +// GetProjectLevel differs from Get in that it returns only the project-level attribute setting for the +// given ResourceType if it exists. The reason this exists is because we want to return project level +// attributes to Flyte Console, regardless of whether a more specific setting exists. +func (r *ResourceRepo) GetProjectLevel(ctx context.Context, ID interfaces.ResourceID) (models.Resource, error) { + if ID.Project == "" { + return models.Resource{}, r.errorTransformer.ToFlyteAdminError(flyteAdminDbErrors.GetInvalidInputError(fmt.Sprintf("%v", ID))) + } + + var resources []models.Resource + timer := r.metrics.GetDuration.Start() + + txWhereClause := "resource_type = ? AND domain = '' AND project = ? AND workflow = '' AND launch_plan = ''" + + tx := r.db.Where(txWhereClause, ID.ResourceType, ID.Project) tx.Order(priorityDescending).First(&resources) timer.Stop() diff --git a/pkg/repositories/gormimpl/resource_repo_test.go b/pkg/repositories/gormimpl/resource_repo_test.go index 00ecb7d41..1ea145d9d 100644 --- a/pkg/repositories/gormimpl/resource_repo_test.go +++ b/pkg/repositories/gormimpl/resource_repo_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" + "gorm.io/gorm" + "github.com/flyteorg/flyteadmin/pkg/repositories/interfaces" mocket "github.com/Selvatico/go-mocket" @@ -14,6 +16,7 @@ import ( ) const resourceTestWorkflowName = "workflow" +const resourceTypeStr = "resource-type" func TestCreateWorkflowAttributes(t *testing.T) { resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) @@ -86,11 +89,11 @@ func TestGetWorkflowAttributes(t *testing.T) { response["project"] = "project" response["domain"] = "domain" response["workflow"] = resourceTestWorkflowName - response["resource_type"] = "resource-type" + response["resource_type"] = resourceTypeStr response["attributes"] = []byte("attrs") query := GlobalMock.NewMock() - query.WithQuery(`SELECT * FROM "resources" WHERE resource_type = $1 AND domain = $2 AND project IN ($3,$4) AND workflow IN ($5,$6) AND launch_plan IN ($7) ORDER BY priority desc,"resources"."id" LIMIT 1`).WithReply( + query.WithQuery(`SELECT * FROM "resources" WHERE resource_type = $1 AND domain IN ($2,$3) AND project IN ($4,$5) AND workflow IN ($6,$7) AND launch_plan IN ($8) ORDER BY priority desc,"resources"."id" LIMIT 1`).WithReply( []map[string]interface{}{ response, }) @@ -100,7 +103,7 @@ func TestGetWorkflowAttributes(t *testing.T) { assert.Equal(t, "project", output.Project) assert.Equal(t, "domain", output.Domain) assert.Equal(t, "workflow", output.Workflow) - assert.Equal(t, "resource-type", output.ResourceType) + assert.Equal(t, resourceTypeStr, output.ResourceType) assert.Equal(t, []byte("attrs"), output.Attributes) } @@ -111,11 +114,11 @@ func TestProjectDomainAttributes(t *testing.T) { response := make(map[string]interface{}) response[project] = project response[domain] = domain - response["resource_type"] = "resource-type" + response["resource_type"] = resourceTypeStr response["attributes"] = []byte("attrs") query := GlobalMock.NewMock() - query.WithQuery(`SELECT * FROM "resources" WHERE resource_type = $1 AND domain = $2 AND project IN ($3,$4) AND workflow IN ($5) AND launch_plan IN ($6) ORDER BY priority desc,"resources"."id" LIMIT 1`).WithReply( + query.WithQuery(`SELECT * FROM "resources" WHERE resource_type = $1 AND domain IN ($2,$3) AND project IN ($4,$5) AND workflow IN ($6) AND launch_plan IN ($7) ORDER BY priority desc,"resources"."id" LIMIT 1`).WithReply( []map[string]interface{}{ response, }) @@ -129,6 +132,35 @@ func TestProjectDomainAttributes(t *testing.T) { assert.Equal(t, []byte("attrs"), output.Attributes) } +func TestProjectLevelAttributes(t *testing.T) { + resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + response := make(map[string]interface{}) + response[project] = project + response[domain] = "" + response["resource_type"] = "resource-type" + response["attributes"] = []byte("attrs") + + query := GlobalMock.NewMock() + query.WithQuery(`SELECT * FROM "resources" WHERE resource_type = $1 AND domain = '' AND project = $2 AND workflow = '' AND launch_plan = '' ORDER BY priority desc,"resources"."id" LIMIT 1`).WithReply( + []map[string]interface{}{ + response, + }) + + output, err := resourceRepo.GetProjectLevel(context.Background(), interfaces.ResourceID{Project: "project", Domain: "", ResourceType: "resource"}) + assert.Nil(t, err) + assert.Equal(t, project, output.Project) + assert.Equal(t, "", output.Domain) + assert.Equal(t, "", output.Workflow) + assert.Equal(t, "resource-type", output.ResourceType) + assert.Equal(t, []byte("attrs"), output.Attributes) + + // Must have a project defined + _, err = resourceRepo.GetProjectLevel(context.Background(), interfaces.ResourceID{Project: "", Domain: "", ResourceType: "resource"}) + assert.Error(t, err) +} + func TestGetRawWorkflowAttributes(t *testing.T) { resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) GlobalMock := mocket.Catcher.Reset() @@ -198,3 +230,18 @@ func TestListAll(t *testing.T) { assert.Equal(t, []byte("attrs"), output[0].Attributes) assert.True(t, fakeResponse.Triggered) } + +func TestGetError(t *testing.T) { + resourceRepo := NewResourceRepo(GetDbForTest(t), errors.NewTestErrorTransformer(), mockScope.NewTestScope()) + GlobalMock := mocket.Catcher.Reset() + GlobalMock.Logging = true + + query := GlobalMock.NewMock() + query.WithQuery(`SELECT * FROM "resources" WHERE resource_type = $1 AND domain IN ($2,$3) AND project IN ($4,$5) AND workflow IN ($6,$7) AND launch_plan IN ($8) ORDER BY priority desc,"resources"."id" LIMIT 1`).WithError(gorm.ErrRecordNotFound) + + output, err := resourceRepo.Get(context.Background(), interfaces.ResourceID{Project: "project", Domain: "domain", Workflow: "workflow", ResourceType: "resource"}) + assert.Error(t, err) + assert.Equal(t, "", output.Project) + assert.Equal(t, "", output.Domain) + assert.Equal(t, "", output.Workflow) +} diff --git a/pkg/repositories/interfaces/resource_repo.go b/pkg/repositories/interfaces/resource_repo.go index b434c8f71..5ffbe0127 100644 --- a/pkg/repositories/interfaces/resource_repo.go +++ b/pkg/repositories/interfaces/resource_repo.go @@ -13,6 +13,9 @@ type ResourceRepoInterface interface { Get(ctx context.Context, ID ResourceID) (models.Resource, error) // Returns a matching Type model. GetRaw(ctx context.Context, ID ResourceID) (models.Resource, error) + // GetProjectLevel returns the Project level resource entry, if any, even if there is a higher + // specificity resource. + GetProjectLevel(ctx context.Context, ID ResourceID) (models.Resource, error) // Lists all resources ListAll(ctx context.Context, resourceType string) ([]models.Resource, error) // Deletes a matching Type model when it exists. diff --git a/pkg/repositories/mocks/resource.go b/pkg/repositories/mocks/resource.go index 9d0962a1e..0def6351e 100644 --- a/pkg/repositories/mocks/resource.go +++ b/pkg/repositories/mocks/resource.go @@ -43,6 +43,14 @@ func (r *MockResourceRepo) GetRaw(ctx context.Context, ID interfaces.ResourceID) return models.Resource{}, nil } +func (r *MockResourceRepo) GetProjectLevel(ctx context.Context, ID interfaces.ResourceID) ( + models.Resource, error) { + if r.GetFunction != nil { + return r.GetFunction(ctx, ID) + } + return models.Resource{}, nil +} + func (r *MockResourceRepo) ListAll(ctx context.Context, resourceType string) ([]models.Resource, error) { if r.ListAllFunction != nil { return r.ListAllFunction(ctx, resourceType) diff --git a/pkg/repositories/models/resource.go b/pkg/repositories/models/resource.go index 02da2fbe3..01c570296 100644 --- a/pkg/repositories/models/resource.go +++ b/pkg/repositories/models/resource.go @@ -5,7 +5,7 @@ import "time" type ResourcePriority int32 const ( - ResourcePriorityDomainLevel ResourcePriority = 1 + ResourcePriorityProjectLevel ResourcePriority = 5 // use this ResourcePriorityProjectDomainLevel ResourcePriority = 10 ResourcePriorityWorkflowLevel ResourcePriority = 100 ResourcePriorityLaunchPlanLevel ResourcePriority = 1000 diff --git a/pkg/repositories/transformers/resource.go b/pkg/repositories/transformers/resource.go index e24e15749..f59e44593 100644 --- a/pkg/repositories/transformers/resource.go +++ b/pkg/repositories/transformers/resource.go @@ -112,8 +112,23 @@ func ProjectDomainAttributesToResourceModel(attributes admin.ProjectDomainAttrib }, nil } -func MergeUpdateProjectDomainAttributes(ctx context.Context, model models.Resource, resource admin.MatchableResource, - resourceID *repoInterfaces.ResourceID, attributes *admin.ProjectDomainAttributes) (models.Resource, error) { +func ProjectAttributesToResourceModel(attributes admin.ProjectAttributes, resource admin.MatchableResource) (models.Resource, error) { + attributeBytes, err := proto.Marshal(attributes.MatchingAttributes) + if err != nil { + return models.Resource{}, err + } + return models.Resource{ + Project: attributes.Project, + ResourceType: resource.String(), + Priority: models.ResourcePriorityProjectLevel, + Attributes: attributeBytes, + }, nil +} + +// MergeUpdatePluginAttributes only handles plugin overrides. Other attributes are just overridden when an +// update happens. +func MergeUpdatePluginAttributes(ctx context.Context, model models.Resource, resource admin.MatchableResource, + resourceID *repoInterfaces.ResourceID, matchingAttributes *admin.MatchingAttributes) (models.Resource, error) { switch resource { case admin.MatchableResource_PLUGIN_OVERRIDE: var existingAttributes admin.MatchingAttributes @@ -122,7 +137,7 @@ func MergeUpdateProjectDomainAttributes(ctx context.Context, model models.Resour return models.Resource{}, errors.NewFlyteAdminErrorf(codes.Internal, "Unable to unmarshal existing resource attributes for [%+v] with err: %v", resourceID, err) } - updatedAttributes := mergeUpdatePluginOverrides(existingAttributes, attributes.GetMatchingAttributes()) + updatedAttributes := mergeUpdatePluginOverrides(existingAttributes, matchingAttributes) marshaledAttributes, err := proto.Marshal(updatedAttributes) if err != nil { return models.Resource{}, errors.NewFlyteAdminErrorf(codes.Internal, diff --git a/pkg/repositories/transformers/resource_test.go b/pkg/repositories/transformers/resource_test.go index c03437792..1a7a32b1b 100644 --- a/pkg/repositories/transformers/resource_test.go +++ b/pkg/repositories/transformers/resource_test.go @@ -86,15 +86,13 @@ func TestMergeUpdateProjectDomainAttributes(t *testing.T) { ResourceType: "PLUGIN_OVERRIDE", Attributes: existingWorkflowAttributes, } - mergeUpdatedModel, err := MergeUpdateProjectDomainAttributes(context.Background(), existingModel, - admin.MatchableResource_PLUGIN_OVERRIDE, &repoInterfaces.ResourceID{}, &admin.ProjectDomainAttributes{ - Project: resourceProject, - Domain: resourceDomain, - MatchingAttributes: testutils.GetPluginOverridesAttributes(map[string][]string{ - "sidecar": {"plugin_c"}, - "hive": {"plugin_d"}, - }), - }) + mergeUpdatedModel, err := MergeUpdatePluginAttributes(context.Background(), existingModel, + admin.MatchableResource_PLUGIN_OVERRIDE, &repoInterfaces.ResourceID{}, + testutils.GetPluginOverridesAttributes(map[string][]string{ + "sidecar": {"plugin_c"}, + "hive": {"plugin_d"}, + }), + ) assert.NoError(t, err) var updatedAttributes admin.MatchingAttributes err = proto.Unmarshal(mergeUpdatedModel.Attributes, &updatedAttributes) @@ -124,8 +122,8 @@ func TestMergeUpdateProjectDomainAttributes(t *testing.T) { Workflow: resourceWorkflow, ResourceType: "PLUGIN_OVERRIDE", } - _, err := MergeUpdateProjectDomainAttributes(context.Background(), existingModel, - admin.MatchableResource_TASK_RESOURCE, &repoInterfaces.ResourceID{}, &admin.ProjectDomainAttributes{}) + _, err := MergeUpdatePluginAttributes(context.Background(), existingModel, + admin.MatchableResource_TASK_RESOURCE, &repoInterfaces.ResourceID{}, &admin.MatchingAttributes{}) assert.Error(t, err, "unsupported resource type") }) } @@ -252,3 +250,20 @@ func TestFromWorkflowAttributesModel_InvalidResourceAttributes(t *testing.T) { assert.NotNil(t, err) assert.Equal(t, codes.Internal, err.(errors.FlyteAdminError).Code()) } + +func TestProjectAttributesToResourceModel(t *testing.T) { + pa := admin.ProjectAttributes{ + Project: resourceProject, + MatchingAttributes: matchingClusterResourceAttributes, + } + rm, err := ProjectAttributesToResourceModel(pa, admin.MatchableResource_CLUSTER_RESOURCE) + + assert.NoError(t, err) + assert.EqualValues(t, models.Resource{ + Project: resourceProject, + Domain: "", + ResourceType: admin.MatchableResource_CLUSTER_RESOURCE.String(), + Priority: models.ResourcePriorityProjectLevel, + Attributes: marshalledClusterResourceAttributes, + }, rm) +} diff --git a/pkg/rpc/adminservice/attributes.go b/pkg/rpc/adminservice/attributes.go index f65df0ed2..7f9a23efe 100644 --- a/pkg/rpc/adminservice/attributes.go +++ b/pkg/rpc/adminservice/attributes.go @@ -117,6 +117,63 @@ func (m *AdminService) DeleteProjectDomainAttributes(ctx context.Context, reques return response, nil } +func (m *AdminService) UpdateProjectAttributes(ctx context.Context, request *admin.ProjectAttributesUpdateRequest) ( + *admin.ProjectAttributesUpdateResponse, error) { + + defer m.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.ProjectAttributesUpdateResponse + var err error + m.Metrics.projectAttributesEndpointMetrics.get.Time(func() { + response, err = m.ResourceManager.UpdateProjectAttributes(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &m.Metrics.projectAttributesEndpointMetrics.get) + } + + return response, nil +} + +func (m *AdminService) GetProjectAttributes(ctx context.Context, request *admin.ProjectAttributesGetRequest) ( + *admin.ProjectAttributesGetResponse, error) { + + defer m.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.ProjectAttributesGetResponse + var err error + m.Metrics.projectAttributesEndpointMetrics.get.Time(func() { + response, err = m.ResourceManager.GetProjectAttributes(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &m.Metrics.projectAttributesEndpointMetrics.get) + } + + return response, nil +} + +func (m *AdminService) DeleteProjectAttributes(ctx context.Context, request *admin.ProjectAttributesDeleteRequest) ( + *admin.ProjectAttributesDeleteResponse, error) { + + defer m.interceptPanic(ctx, request) + if request == nil { + return nil, status.Errorf(codes.InvalidArgument, "Incorrect request, nil requests not allowed") + } + var response *admin.ProjectAttributesDeleteResponse + var err error + m.Metrics.projectAttributesEndpointMetrics.delete.Time(func() { + response, err = m.ResourceManager.DeleteProjectAttributes(ctx, *request) + }) + if err != nil { + return nil, util.TransformAndRecordError(err, &m.Metrics.projectAttributesEndpointMetrics.delete) + } + + return response, nil +} + func (m *AdminService) ListMatchableAttributes(ctx context.Context, request *admin.ListMatchableAttributesRequest) ( *admin.ListMatchableAttributesResponse, error) { defer m.interceptPanic(ctx, request) diff --git a/pkg/rpc/adminservice/tests/project_domain_test.go b/pkg/rpc/adminservice/tests/project_domain_test.go index df65679f6..cc84bb0bf 100644 --- a/pkg/rpc/adminservice/tests/project_domain_test.go +++ b/pkg/rpc/adminservice/tests/project_domain_test.go @@ -35,3 +35,79 @@ func TestUpdateProjectDomain(t *testing.T) { assert.NoError(t, err) assert.True(t, updateCalled) } + +func TestUpdateProjectAttr(t *testing.T) { + ctx := context.Background() + + mockProjectDomainManager := mocks.MockResourceManager{} + var updateCalled bool + mockProjectDomainManager.SetUpdateProjectAttributes( + func(ctx context.Context, + request admin.ProjectAttributesUpdateRequest) (*admin.ProjectAttributesUpdateResponse, error) { + updateCalled = true + return &admin.ProjectAttributesUpdateResponse{}, nil + }, + ) + mockServer := NewMockAdminServer(NewMockAdminServerInput{ + resourceManager: &mockProjectDomainManager, + }) + + resp, err := mockServer.UpdateProjectAttributes(ctx, &admin.ProjectAttributesUpdateRequest{ + Attributes: &admin.ProjectAttributes{ + Project: "project", + }, + }) + assert.NotNil(t, resp) + assert.NoError(t, err) + assert.True(t, updateCalled) +} + +func TestDeleteProjectAttr(t *testing.T) { + ctx := context.Background() + + mockProjectDomainManager := mocks.MockResourceManager{} + var deleteCalled bool + mockProjectDomainManager.SetDeleteProjectAttributes( + func(ctx context.Context, + request admin.ProjectAttributesDeleteRequest) (*admin.ProjectAttributesDeleteResponse, error) { + deleteCalled = true + return &admin.ProjectAttributesDeleteResponse{}, nil + }, + ) + mockServer := NewMockAdminServer(NewMockAdminServerInput{ + resourceManager: &mockProjectDomainManager, + }) + + resp, err := mockServer.DeleteProjectAttributes(ctx, &admin.ProjectAttributesDeleteRequest{ + Project: "project", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, + }) + assert.NotNil(t, resp) + assert.NoError(t, err) + assert.True(t, deleteCalled) +} + +func TestGetProjectAttr(t *testing.T) { + ctx := context.Background() + + mockProjectDomainManager := mocks.MockResourceManager{} + var getCalled bool + mockProjectDomainManager.SetGetProjectAttributes( + func(ctx context.Context, + request admin.ProjectAttributesGetRequest) (*admin.ProjectAttributesGetResponse, error) { + getCalled = true + return &admin.ProjectAttributesGetResponse{}, nil + }, + ) + mockServer := NewMockAdminServer(NewMockAdminServerInput{ + resourceManager: &mockProjectDomainManager, + }) + + resp, err := mockServer.GetProjectAttributes(ctx, &admin.ProjectAttributesGetRequest{ + Project: "project", + ResourceType: admin.MatchableResource_WORKFLOW_EXECUTION_CONFIG, + }) + assert.NotNil(t, resp) + assert.NoError(t, err) + assert.True(t, getCalled) +} diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go index 2f99091e2..21ca5fd7e 100644 --- a/pkg/runtime/interfaces/application_configuration.go +++ b/pkg/runtime/interfaces/application_configuration.go @@ -158,6 +158,31 @@ func (a *ApplicationConfig) GetInterruptible() *wrappers.BoolValue { } } +// GetAsWorkflowExecutionConfig returns the WorkflowExecutionConfig as extracted from this object +func (a *ApplicationConfig) GetAsWorkflowExecutionConfig() admin.WorkflowExecutionConfig { + // These two should always be set, one is a number, and the other returns nil when empty. + wec := admin.WorkflowExecutionConfig{ + MaxParallelism: a.GetMaxParallelism(), + Interruptible: a.GetInterruptible(), + } + + // For the others, we only add the field when the field is set in the config. + if a.GetSecurityContext().RunAs.GetK8SServiceAccount() != "" || a.GetSecurityContext().RunAs.GetIamRole() != "" { + wec.SecurityContext = a.GetSecurityContext() + } + if a.GetRawOutputDataConfig().OutputLocationPrefix != "" { + wec.RawOutputDataConfig = a.GetRawOutputDataConfig() + } + if len(a.GetLabels().Values) > 0 { + wec.Labels = a.GetLabels() + } + if len(a.GetAnnotations().Values) > 0 { + wec.Annotations = a.GetAnnotations() + } + + return wec +} + // This section holds common config for AWS type AWSConfig struct { Region string `json:"region"`