From 4cc23810f1a4a8755d6909985b35e138d6d80373 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nick=20M=C3=BCller?= Date: Mon, 14 Nov 2022 23:17:22 +0100 Subject: [PATCH] Skipping of cached task outputs via execution config (#482) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implemented SkipCache handling for execution config Signed-off-by: Nick Müller * Added cache skip override to RelaunchExecution endpoint Signed-off-by: Nick Müller * Updated to latest version of flytepropeller Signed-off-by: Nick Müller * Renamed SkipCache flag to OverwriteCache Updated to latest released versions of flyteidl, flyteplugins and flytestdlib Updated to latest unmerged version of flytepropeller Signed-off-by: Nick Müller * Updated flyteidl, flytepropeller and flytestdlib to latest released versions Signed-off-by: Nick Müller * Reworded comment for clarity Signed-off-by: Nick Müller Signed-off-by: Nick Müller --- go.mod | 16 +- go.sum | 33 +- pkg/common/mocks/storage.go | 7 +- pkg/manager/impl/execution_manager.go | 1 + pkg/manager/impl/execution_manager_test.go | 428 ++++++++++++++++++ pkg/manager/impl/shared/iface.go | 2 + pkg/manager/impl/util/shared.go | 7 +- .../interfaces/application_configuration.go | 13 +- pkg/workflowengine/impl/prepare_execution.go | 3 + .../impl/prepare_execution_test.go | 8 + 10 files changed, 492 insertions(+), 26 deletions(-) diff --git a/go.mod b/go.mod index 4270a8a58..b42085d07 100644 --- a/go.mod +++ b/go.mod @@ -13,10 +13,10 @@ require ( github.com/cloudevents/sdk-go/v2 v2.8.0 github.com/coreos/go-oidc v2.2.1+incompatible github.com/evanphx/json-patch v4.12.0+incompatible - github.com/flyteorg/flyteidl v1.2.0 - github.com/flyteorg/flyteplugins v1.0.10 - github.com/flyteorg/flytepropeller v1.1.28 - github.com/flyteorg/flytestdlib v1.0.5 + github.com/flyteorg/flyteidl v1.2.5 + github.com/flyteorg/flyteplugins v1.0.18 + github.com/flyteorg/flytepropeller v1.1.47 + github.com/flyteorg/flytestdlib v1.0.12 github.com/flyteorg/stow v0.3.6 github.com/ghodss/yaml v1.0.0 github.com/go-gormigrate/gormigrate/v2 v2.0.0 @@ -168,12 +168,12 @@ require ( go.opencensus.io v0.23.0 // indirect golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect - golang.org/x/net v0.0.0-20220607020251-c690dde0001d // indirect - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect - golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68 // indirect + golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect + golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 // indirect + golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 // indirect golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 // indirect golang.org/x/text v0.3.7 // indirect - golang.org/x/tools v0.1.11 // indirect + golang.org/x/tools v0.1.12 // indirect golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f // indirect google.golang.org/appengine v1.6.7 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index 4d8ad98cd..4bd8bb2f9 100644 --- a/go.sum +++ b/go.sum @@ -352,15 +352,15 @@ github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8SPQ= github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/flyteorg/flyteidl v1.2.0 h1:snJPpc5a5Gr4GXYiAMX6Io1edT91ZxN/7oE6uhydrvk= -github.com/flyteorg/flyteidl v1.2.0/go.mod h1:f0AFl7RFycH7+JLq2th0ReH7v+Xse+QTw4jGdIxiS8I= -github.com/flyteorg/flyteplugins v1.0.10 h1:XBycM4aOSE/WlI8iP9vqogKGXy4FMfVCUUfzxJus/p4= -github.com/flyteorg/flyteplugins v1.0.10/go.mod h1:GfbmRByI/rSatm/Epoj3bNyrXwIQ9NOXTVwLS6Z0p84= -github.com/flyteorg/flytepropeller v1.1.28 h1:68qQ0QRHoCzagF0oifkW/c4A1L4B4LdgyHCPLKMiY2g= -github.com/flyteorg/flytepropeller v1.1.28/go.mod h1:QE3szUWkFnyFg3mMxpn3y93ZSs18T+1SQtVgNhcEMvA= +github.com/flyteorg/flyteidl v1.2.5 h1:oPs0PX9opR9JtWjP5ZH2YMChkbGGL45PIy+90FlaxYc= +github.com/flyteorg/flyteidl v1.2.5/go.mod h1:OJAq333OpInPnMhvVz93AlEjmlQ+t0FAD4aakIYE4OU= +github.com/flyteorg/flyteplugins v1.0.18 h1:DOyxAFaS4luv7H9XRKUpHbO09imsG4LP8Du515FGXyM= +github.com/flyteorg/flyteplugins v1.0.18/go.mod h1:ZbZVBxEWh8Icj1AgfNKg0uPzHHGd9twa4eWcY2Yt6xE= +github.com/flyteorg/flytepropeller v1.1.47 h1:k+moR+YGOyKJnYHDZjBBXvwnuZJ7IhK/PRv/9Ak/QIs= +github.com/flyteorg/flytepropeller v1.1.47/go.mod h1:vZlQTBOsddrNGxmA0To+B2ld3VFg6sRWwcC4KU7+g9A= github.com/flyteorg/flytestdlib v1.0.0/go.mod h1:QSVN5wIM1lM9d60eAEbX7NwweQXW96t5x4jbyftn89c= -github.com/flyteorg/flytestdlib v1.0.5 h1:80A/vfpAJl+pgU6vxccbsYApZPrvyGhOIsCAFngsjnk= -github.com/flyteorg/flytestdlib v1.0.5/go.mod h1:WTe0k3DmmrKFjj3hwiIbjjdCK89X63MBzBbXhQ4Yxf0= +github.com/flyteorg/flytestdlib v1.0.12 h1:A+yN5TX/SezjCjzv/JV29SzlBAyKGeLDOfAiYqzrKcw= +github.com/flyteorg/flytestdlib v1.0.12/go.mod h1:nIBmBHtjTJvhZEn3e/EwVC/iMkR2tUX8hEiXjRBpH/s= github.com/flyteorg/stow v0.3.3/go.mod h1:HBld7ud0i4khMHwJjkO8v+NSP7ddKa/ruhf4I8fliaA= github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk= github.com/flyteorg/stow v0.3.6/go.mod h1:5dfBitPM004dwaZdoVylVjxFT4GWAgI0ghAndhNUzCo= @@ -1468,6 +1468,7 @@ github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtX github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/vektra/mockery v1.1.2/go.mod h1:VcfZjKaFOPO+MpN4ZvwPjs4c48lkq1o3Ym8yHZJu0jU= github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -1715,8 +1716,8 @@ golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220325170049-de3da57026de/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220412020605-290c469a71a5/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/net v0.0.0-20220607020251-c690dde0001d h1:4SFsTMi4UahlKoloni7L4eYzhFRifURQLw+yv0QDCx8= -golang.org/x/net v0.0.0-20220607020251-c690dde0001d/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181003184128-c57b0facaced/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1749,8 +1750,9 @@ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180816055513-1c9583448a9c/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1864,8 +1866,8 @@ golang.org/x/sys v0.0.0-20220209214540-3681064d5158/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220328115105-d36c6a25d886/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68 h1:z8Hj/bl9cOV2grsOpEaQFUaly0JWN3i97mo3jXKJNp0= -golang.org/x/sys v0.0.0-20220608164250-635b8c9b7f68/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY= +golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY= @@ -1967,6 +1969,7 @@ golang.org/x/tools v0.0.0-20200224181240-023911ca70b2/go.mod h1:TB2adYChydJhpapK golang.org/x/tools v0.0.0-20200227222343-706bc42d1f0d/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20200304193943-95d2e580d8eb/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= golang.org/x/tools v0.0.0-20200312045724-11d5b4c81c7d/go.mod h1:o4KQGtdN14AW+yjsvvwRTJJuXz8XRtIHtEnmAXLyFUw= +golang.org/x/tools v0.0.0-20200323144430-8dcfad9e016e/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= golang.org/x/tools v0.0.0-20200331025713-a30bf2db82d4/go.mod h1:Sl4aGygMT6LrqrWclx+PTx3U+LnKx/seiNR+3G19Ar8= golang.org/x/tools v0.0.0-20200501065659-ab2804fb9c9d/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20200505023115-26f46d2f7ef8/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -1998,8 +2001,8 @@ golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.10-0.20220218145154-897bd77cd717/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.1.11 h1:loJ25fNOEhSXfHrpoGj91eCUThwdNX6u24rO1xnNteY= -golang.org/x/tools v0.1.11/go.mod h1:SgwaegtQh8clINPpECJMqnxLv9I09HLqnW3RMqW0CA4= +golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/common/mocks/storage.go b/pkg/common/mocks/storage.go index 9edc5c1ed..d1555ebf4 100644 --- a/pkg/common/mocks/storage.go +++ b/pkg/common/mocks/storage.go @@ -24,7 +24,8 @@ type TestDataStore struct { ctx context.Context, reference storage.DataReference, opts storage.Options, msg proto.Message) error ConstructReferenceCb func( ctx context.Context, reference storage.DataReference, nestedKeys ...string) (storage.DataReference, error) - Store map[storage.DataReference][]byte + DeleteCb func(ctx context.Context, reference storage.DataReference) error + Store map[storage.DataReference][]byte } func (t *TestDataStore) Head(ctx context.Context, reference storage.DataReference) (storage.Metadata, error) { @@ -77,6 +78,10 @@ func (t *TestDataStore) ConstructReference( return storage.DataReference(fmt.Sprintf("%s/%v", reference, nestedPath)), nil } +func (t *TestDataStore) Delete(ctx context.Context, reference storage.DataReference) error { + return t.DeleteCb(ctx, reference) +} + func GetMockStorageClient() *storage.DataStore { mockStorageClient := TestDataStore{ Store: make(map[storage.DataReference][]byte), diff --git a/pkg/manager/impl/execution_manager.go b/pkg/manager/impl/execution_manager.go index f5ab8a0d2..ac513cc3f 100644 --- a/pkg/manager/impl/execution_manager.go +++ b/pkg/manager/impl/execution_manager.go @@ -1060,6 +1060,7 @@ func (m *ExecutionManager) RelaunchExecution( } executionSpec.Metadata.Mode = admin.ExecutionMetadata_RELAUNCH executionSpec.Metadata.ReferenceExecution = existingExecution.Id + executionSpec.OverwriteCache = request.GetOverwriteCache() var executionModel *models.Execution ctx, executionModel, err = m.launchExecutionAndPrepareModel(ctx, admin.ExecutionCreateRequest{ Project: request.Id.Project, diff --git a/pkg/manager/impl/execution_manager_test.go b/pkg/manager/impl/execution_manager_test.go index 309f3cde0..b841a1a37 100644 --- a/pkg/manager/impl/execution_manager_test.go +++ b/pkg/manager/impl/execution_manager_test.go @@ -1001,6 +1001,86 @@ func TestCreateExecutionInterruptible(t *testing.T) { } } +func TestCreateExecutionOverwriteCache(t *testing.T) { + tests := []struct { + name string + task bool + overwriteCache bool + want bool + }{ + { + name: "LaunchPlanDefault", + task: false, + overwriteCache: false, + want: false, + }, + { + name: "LaunchPlanEnable", + task: false, + overwriteCache: true, + want: true, + }, + { + name: "TaskDefault", + task: false, + overwriteCache: false, + want: false, + }, + { + name: "TaskEnable", + task: true, + overwriteCache: true, + want: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + request := testutils.GetExecutionRequest() + if tt.task { + request.Spec.LaunchPlan.ResourceType = core.ResourceType_TASK + } + request.Spec.OverwriteCache = tt.overwriteCache + + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + setDefaultTaskCallbackForExecTest(repository) + + exCreateFunc := func(ctx context.Context, input models.Execution) error { + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + + if tt.task { + assert.Equal(t, uint(0), input.LaunchPlanID) + assert.NotEqual(t, uint(0), input.TaskID) + } else { + assert.NotEqual(t, uint(0), input.LaunchPlanID) + assert.Equal(t, uint(0), input.TaskID) + } + + assert.Equal(t, tt.overwriteCache, spec.GetOverwriteCache()) + + return nil + } + + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + mockExecutor := workflowengineMocks.WorkflowExecutor{} + mockExecutor.OnExecuteMatch(mock.Anything, mock.Anything, mock.Anything).Return(workflowengineInterfaces.ExecutionResponse{}, nil) + mockExecutor.OnID().Return("testMockExecutor") + r := plugins.NewRegistry() + r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor) + execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + + _, err := execManager.CreateExecution(context.Background(), request, requestedAt) + assert.Nil(t, err) + }) + } +} + func makeExecutionGetFunc( t *testing.T, closureBytes []byte, startTime *time.Time) repositoryMocks.GetExecutionFunc { return func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { @@ -1090,6 +1170,39 @@ func makeExecutionInterruptibleGetFunc( } } +func makeExecutionOverwriteCacheGetFunc( + t *testing.T, closureBytes []byte, startTime *time.Time, overwriteCache bool) repositoryMocks.GetExecutionFunc { + return func(ctx context.Context, input interfaces.Identifier) (models.Execution, error) { + assert.Equal(t, "project", input.Project) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "name", input.Name) + + request := testutils.GetExecutionRequest() + request.Spec.OverwriteCache = overwriteCache + + specBytes, err := proto.Marshal(request.Spec) + assert.Nil(t, err) + + return models.Execution{ + ExecutionKey: models.ExecutionKey{ + Project: "project", + Domain: "domain", + Name: "name", + }, + BaseModel: models.BaseModel{ + ID: uint(8), + }, + Spec: specBytes, + Phase: core.WorkflowExecution_QUEUED.String(), + Closure: closureBytes, + LaunchPlanID: uint(1), + WorkflowID: uint(2), + StartedAt: startTime, + Cluster: testCluster, + }, nil + } +} + func TestRelaunchExecution(t *testing.T) { // Set up mocks. repository := getMockRepositoryForExecTest() @@ -1280,6 +1393,129 @@ func TestRelaunchExecutionInterruptibleOverride(t *testing.T) { assert.True(t, createCalled) } +func TestRelaunchExecutionOverwriteCacheOverride(t *testing.T) { + // Set up mocks. + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + mockExecutor := workflowengineMocks.WorkflowExecutor{} + mockExecutor.OnExecuteMatch(mock.Anything, mock.Anything, mock.Anything).Return(workflowengineInterfaces.ExecutionResponse{}, nil) + mockExecutor.OnID().Return("testMockExecutor") + r := plugins.NewRegistry() + r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor) + execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + startTime := time.Now() + startTimeProto, _ := ptypes.TimestampProto(startTime) + existingClosure := admin.ExecutionClosure{ + Phase: core.WorkflowExecution_RUNNING, + StartedAt: startTimeProto, + } + existingClosureBytes, _ := proto.Marshal(&existingClosure) + + t.Run("override enable", func(t *testing.T) { + executionGetFunc := makeExecutionOverwriteCacheGetFunc(t, existingClosureBytes, &startTime, false) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + + var createCalled bool + exCreateFunc := func(ctx context.Context, input models.Execution) error { + createCalled = true + assert.Equal(t, "relaunchy", input.Name) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "project", input.Project) + assert.Equal(t, uint(8), input.SourceExecutionID) + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + assert.Equal(t, admin.ExecutionMetadata_RELAUNCH, spec.Metadata.Mode) + assert.Equal(t, int32(admin.ExecutionMetadata_RELAUNCH), input.Mode) + assert.True(t, spec.GetOverwriteCache()) + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + + asd, err := execManager.RelaunchExecution(context.Background(), admin.ExecutionRelaunchRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "relaunchy", + OverwriteCache: true, + }, requestedAt) + assert.Nil(t, err) + assert.NotNil(t, asd) + assert.True(t, createCalled) + }) + + t.Run("override disable", func(t *testing.T) { + executionGetFunc := makeExecutionOverwriteCacheGetFunc(t, existingClosureBytes, &startTime, true) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + + var createCalled bool + exCreateFunc := func(ctx context.Context, input models.Execution) error { + createCalled = true + assert.Equal(t, "relaunchy", input.Name) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "project", input.Project) + assert.Equal(t, uint(8), input.SourceExecutionID) + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + assert.Equal(t, admin.ExecutionMetadata_RELAUNCH, spec.Metadata.Mode) + assert.Equal(t, int32(admin.ExecutionMetadata_RELAUNCH), input.Mode) + assert.False(t, spec.GetOverwriteCache()) + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + + asd, err := execManager.RelaunchExecution(context.Background(), admin.ExecutionRelaunchRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "relaunchy", + OverwriteCache: false, + }, requestedAt) + assert.Nil(t, err) + assert.NotNil(t, asd) + assert.True(t, createCalled) + }) + + t.Run("override omitted", func(t *testing.T) { + executionGetFunc := makeExecutionOverwriteCacheGetFunc(t, existingClosureBytes, &startTime, true) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + + var createCalled bool + exCreateFunc := func(ctx context.Context, input models.Execution) error { + createCalled = true + assert.Equal(t, "relaunchy", input.Name) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "project", input.Project) + assert.Equal(t, uint(8), input.SourceExecutionID) + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + assert.Equal(t, admin.ExecutionMetadata_RELAUNCH, spec.Metadata.Mode) + assert.Equal(t, int32(admin.ExecutionMetadata_RELAUNCH), input.Mode) + assert.False(t, spec.GetOverwriteCache()) + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + + asd, err := execManager.RelaunchExecution(context.Background(), admin.ExecutionRelaunchRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "relaunchy", + }, requestedAt) + assert.Nil(t, err) + assert.NotNil(t, asd) + assert.True(t, createCalled) + }) +} + func TestRecoverExecution(t *testing.T) { // Set up mocks. repository := getMockRepositoryForExecTest() @@ -1580,6 +1816,67 @@ func TestRecoverExecutionInterruptibleOverride(t *testing.T) { assert.True(t, proto.Equal(expectedResponse, response)) } +func TestRecoverExecutionOverwriteCacheOverride(t *testing.T) { + // Set up mocks. + repository := getMockRepositoryForExecTest() + setDefaultLpCallbackForExecTest(repository) + mockExecutor := workflowengineMocks.WorkflowExecutor{} + mockExecutor.OnExecuteMatch(mock.Anything, mock.Anything, mock.Anything).Return(workflowengineInterfaces.ExecutionResponse{}, nil) + mockExecutor.OnID().Return("testMockExecutor") + r := plugins.NewRegistry() + r.RegisterDefault(plugins.PluginIDWorkflowExecutor, &mockExecutor) + execManager := NewExecutionManager(repository, r, getMockExecutionsConfigProvider(), getMockStorageForExecTest(context.Background()), mockScope.NewTestScope(), mockScope.NewTestScope(), &mockPublisher, mockExecutionRemoteURL, nil, nil, nil, nil, &eventWriterMocks.WorkflowExecutionEventWriter{}) + startTime := time.Now() + startTimeProto, _ := ptypes.TimestampProto(startTime) + existingClosure := admin.ExecutionClosure{ + Phase: core.WorkflowExecution_SUCCEEDED, + StartedAt: startTimeProto, + } + existingClosureBytes, _ := proto.Marshal(&existingClosure) + executionGetFunc := makeExecutionOverwriteCacheGetFunc(t, existingClosureBytes, &startTime, true) + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetGetCallback(executionGetFunc) + + var createCalled bool + exCreateFunc := func(ctx context.Context, input models.Execution) error { + createCalled = true + assert.Equal(t, "recovered", input.Name) + assert.Equal(t, "domain", input.Domain) + assert.Equal(t, "project", input.Project) + assert.Equal(t, uint(8), input.SourceExecutionID) + var spec admin.ExecutionSpec + err := proto.Unmarshal(input.Spec, &spec) + assert.Nil(t, err) + assert.Equal(t, admin.ExecutionMetadata_RECOVERED, spec.Metadata.Mode) + assert.Equal(t, int32(admin.ExecutionMetadata_RECOVERED), input.Mode) + assert.True(t, spec.GetOverwriteCache()) + return nil + } + repository.ExecutionRepo().(*repositoryMocks.MockExecutionRepo).SetCreateCallback(exCreateFunc) + + // Issue request. + response, err := execManager.RecoverExecution(context.Background(), admin.ExecutionRecoverRequest{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + }, + Name: "recovered", + }, requestedAt) + + // And verify response. + assert.Nil(t, err) + + expectedResponse := &admin.ExecutionCreateResponse{ + Id: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "recovered", + }, + } + assert.True(t, createCalled) + assert.True(t, proto.Equal(expectedResponse, response)) +} + func TestCreateWorkflowEvent(t *testing.T) { repository := repositoryMocks.NewMockRepository() startTime := time.Now() @@ -3997,6 +4294,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { requestK8sServiceAccount := "requestK8sServiceAccount" requestMaxParallelism := int32(10) requestInterruptible := false + requestOverwriteCache := false launchPlanLabels := map[string]string{"launchPlanLabelKey": "launchPlanLabelValue"} launchPlanAnnotations := map[string]string{"launchPlanAnnotationKey": "launchPlanAnnotationValue"} @@ -4005,6 +4303,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { launchPlanAssumableIamRole := "launchPlanAssumableIamRole" launchPlanMaxParallelism := int32(50) launchPlanInterruptible := true + launchPlanOverwriteCache := true applicationConfig := runtime.NewConfigurationProvider() @@ -4018,6 +4317,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { rmK8sServiceAccount := "rmK8sServiceAccount" rmMaxParallelism := int32(80) rmInterruptible := false + rmOverwriteCache := false resourceManager := managerMocks.MockResourceManager{} executionManager := ExecutionManager{ @@ -4041,6 +4341,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ MaxParallelism: rmMaxParallelism, Interruptible: &wrappers.BoolValue{Value: rmInterruptible}, + OverwriteCache: rmOverwriteCache, Annotations: &admin.Annotations{Values: rmAnnotations}, RawOutputDataConfig: &admin.RawOutputDataConfig{ OutputLocationPrefix: rmOutputLocationPrefix, @@ -4090,6 +4391,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { }, MaxParallelism: requestMaxParallelism, Interruptible: &wrappers.BoolValue{Value: requestInterruptible}, + OverwriteCache: requestOverwriteCache, }, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, nil) @@ -4097,6 +4399,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Equal(t, requestMaxParallelism, execConfig.MaxParallelism) assert.Equal(t, requestK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Equal(t, requestInterruptible, execConfig.Interruptible.Value) + assert.Equal(t, requestOverwriteCache, execConfig.OverwriteCache) assert.Equal(t, requestOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, requestLabels, execConfig.GetLabels().Values) assert.Equal(t, requestAnnotations, execConfig.GetAnnotations().Values) @@ -4126,12 +4429,14 @@ func TestGetExecutionConfigOverrides(t *testing.T) { }, MaxParallelism: launchPlanMaxParallelism, Interruptible: &wrappers.BoolValue{Value: launchPlanInterruptible}, + OverwriteCache: launchPlanOverwriteCache, }, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) assert.NoError(t, err) assert.Equal(t, requestMaxParallelism, execConfig.MaxParallelism) assert.Equal(t, launchPlanInterruptible, execConfig.Interruptible.Value) + assert.Equal(t, launchPlanOverwriteCache, execConfig.OverwriteCache) assert.True(t, proto.Equal(launchPlan.Spec.SecurityContext, execConfig.SecurityContext)) assert.True(t, proto.Equal(launchPlan.Spec.Annotations, execConfig.Annotations)) assert.Equal(t, requestOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) @@ -4162,12 +4467,14 @@ func TestGetExecutionConfigOverrides(t *testing.T) { }, MaxParallelism: launchPlanMaxParallelism, Interruptible: &wrappers.BoolValue{Value: launchPlanInterruptible}, + OverwriteCache: launchPlanOverwriteCache, }, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) assert.NoError(t, err) assert.Equal(t, launchPlanMaxParallelism, execConfig.MaxParallelism) assert.Equal(t, launchPlanInterruptible, execConfig.Interruptible.Value) + assert.Equal(t, launchPlanOverwriteCache, execConfig.OverwriteCache) assert.Equal(t, launchPlanK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Equal(t, launchPlanOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, launchPlanLabels, execConfig.GetLabels().Values) @@ -4192,12 +4499,14 @@ func TestGetExecutionConfigOverrides(t *testing.T) { }, MaxParallelism: launchPlanMaxParallelism, Interruptible: &wrappers.BoolValue{Value: launchPlanInterruptible}, + OverwriteCache: launchPlanOverwriteCache, }, } execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) assert.NoError(t, err) assert.Equal(t, launchPlanMaxParallelism, execConfig.MaxParallelism) assert.Equal(t, launchPlanInterruptible, execConfig.Interruptible.Value) + assert.Equal(t, launchPlanOverwriteCache, execConfig.OverwriteCache) assert.Equal(t, launchPlanK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Equal(t, launchPlanOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, launchPlanLabels, execConfig.GetLabels().Values) @@ -4228,6 +4537,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.NoError(t, err) assert.Equal(t, launchPlanMaxParallelism, execConfig.MaxParallelism) assert.Equal(t, rmInterruptible, execConfig.Interruptible.Value) + assert.Equal(t, rmOverwriteCache, execConfig.OverwriteCache) assert.Equal(t, launchPlanK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Equal(t, launchPlanOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, launchPlanLabels, execConfig.GetLabels().Values) @@ -4246,6 +4556,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.NoError(t, err) assert.Equal(t, rmMaxParallelism, execConfig.MaxParallelism) assert.Equal(t, rmInterruptible, execConfig.Interruptible.Value) + assert.Equal(t, rmOverwriteCache, execConfig.OverwriteCache) assert.Equal(t, rmK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Equal(t, rmOutputLocationPrefix, execConfig.RawOutputDataConfig.OutputLocationPrefix) assert.Equal(t, rmLabels, execConfig.GetLabels().Values) @@ -4291,6 +4602,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.NoError(t, err) assert.Equal(t, rmMaxParallelism, execConfig.MaxParallelism) assert.Nil(t, execConfig.GetInterruptible()) + assert.False(t, execConfig.OverwriteCache) assert.Equal(t, rmK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) @@ -4327,6 +4639,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.NoError(t, err) assert.Equal(t, defaultMaxParallelism, execConfig.MaxParallelism) assert.Nil(t, execConfig.GetInterruptible()) + assert.False(t, execConfig.OverwriteCache) assert.Equal(t, defaultK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) @@ -4368,6 +4681,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.NoError(t, err) assert.Equal(t, defaultMaxParallelism, execConfig.MaxParallelism) assert.Nil(t, execConfig.GetInterruptible()) + assert.False(t, execConfig.OverwriteCache) assert.Equal(t, deprecatedLaunchPlanK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) @@ -4393,6 +4707,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ MaxParallelism: 300, Interruptible: &wrappers.BoolValue{Value: true}, + OverwriteCache: true, SecurityContext: &core.SecurityContext{ RunAs: &core.Identity{ K8SServiceAccount: "workflowDefault", @@ -4419,6 +4734,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.NoError(t, err) assert.Equal(t, int32(300), execConfig.MaxParallelism) assert.True(t, execConfig.Interruptible.Value) + assert.True(t, execConfig.OverwriteCache) assert.Equal(t, "workflowDefault", execConfig.SecurityContext.RunAs.K8SServiceAccount) assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) @@ -4448,6 +4764,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) assert.Equal(t, fmt.Errorf("failed to fetch the resources"), err) assert.Nil(t, execConfig.GetInterruptible()) + assert.False(t, execConfig.GetOverwriteCache()) assert.Nil(t, execConfig.GetSecurityContext()) assert.Nil(t, execConfig.GetRawOutputDataConfig()) assert.Nil(t, execConfig.GetLabels()) @@ -4474,6 +4791,7 @@ func TestGetExecutionConfigOverrides(t *testing.T) { } executionManager.config.ApplicationConfiguration().GetTopLevelConfig().Interruptible = true + executionManager.config.ApplicationConfiguration().GetTopLevelConfig().OverwriteCache = true t.Run("request with interruptible override disabled", func(t *testing.T) { request := &admin.ExecutionCreateRequest{ @@ -4615,6 +4933,106 @@ func TestGetExecutionConfigOverrides(t *testing.T) { assert.Nil(t, execConfig.GetLabels()) assert.Nil(t, execConfig.GetAnnotations()) }) + t.Run("request with skip cache override enabled", func(t *testing.T) { + request := &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{ + OverwriteCache: true, + }, + } + + execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, nil) + assert.NoError(t, err) + assert.Equal(t, defaultMaxParallelism, execConfig.MaxParallelism) + assert.True(t, execConfig.OverwriteCache) + assert.Equal(t, defaultK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) + assert.Nil(t, execConfig.GetRawOutputDataConfig()) + assert.Nil(t, execConfig.GetLabels()) + assert.Nil(t, execConfig.GetAnnotations()) + }) + t.Run("request with no skip cache override specified", func(t *testing.T) { + request := &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{}, + } + + execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, nil) + assert.NoError(t, err) + assert.Equal(t, defaultMaxParallelism, execConfig.MaxParallelism) + assert.True(t, execConfig.OverwriteCache) + assert.Equal(t, defaultK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) + assert.Nil(t, execConfig.GetRawOutputDataConfig()) + assert.Nil(t, execConfig.GetLabels()) + assert.Nil(t, execConfig.GetAnnotations()) + }) + t.Run("launch plan with skip cache override enabled", func(t *testing.T) { + request := &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{}, + } + + launchPlan := &admin.LaunchPlan{ + Spec: &admin.LaunchPlanSpec{ + OverwriteCache: true, + }, + } + + execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) + assert.NoError(t, err) + assert.Equal(t, defaultMaxParallelism, execConfig.MaxParallelism) + assert.True(t, execConfig.OverwriteCache) + assert.Equal(t, defaultK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) + assert.Nil(t, execConfig.GetRawOutputDataConfig()) + assert.Nil(t, execConfig.GetLabels()) + assert.Nil(t, execConfig.GetAnnotations()) + }) + t.Run("launch plan with no skip cache override specified", func(t *testing.T) { + request := &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{}, + } + + launchPlan := &admin.LaunchPlan{ + Spec: &admin.LaunchPlanSpec{}, + } + + execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) + assert.NoError(t, err) + assert.Equal(t, defaultMaxParallelism, execConfig.MaxParallelism) + assert.True(t, execConfig.OverwriteCache) + assert.Equal(t, defaultK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) + assert.Nil(t, execConfig.GetRawOutputDataConfig()) + assert.Nil(t, execConfig.GetLabels()) + assert.Nil(t, execConfig.GetAnnotations()) + }) + t.Run("request and launch plan with different skip cache overrides", func(t *testing.T) { + request := &admin.ExecutionCreateRequest{ + Project: workflowIdentifier.Project, + Domain: workflowIdentifier.Domain, + Spec: &admin.ExecutionSpec{ + OverwriteCache: true, + }, + } + + launchPlan := &admin.LaunchPlan{ + Spec: &admin.LaunchPlanSpec{ + OverwriteCache: false, + }, + } + + execConfig, err := executionManager.getExecutionConfig(context.TODO(), request, launchPlan) + assert.NoError(t, err) + assert.Equal(t, defaultMaxParallelism, execConfig.MaxParallelism) + assert.True(t, execConfig.OverwriteCache) + assert.Equal(t, defaultK8sServiceAccount, execConfig.SecurityContext.RunAs.K8SServiceAccount) + assert.Nil(t, execConfig.GetRawOutputDataConfig()) + assert.Nil(t, execConfig.GetLabels()) + assert.Nil(t, execConfig.GetAnnotations()) + }) }) } @@ -4635,6 +5053,7 @@ func TestGetExecutionConfig(t *testing.T) { Target: &admin.MatchingAttributes_WorkflowExecutionConfig{ WorkflowExecutionConfig: &admin.WorkflowExecutionConfig{ MaxParallelism: 100, + OverwriteCache: true, }, }, }, @@ -4653,6 +5072,7 @@ func TestGetExecutionConfig(t *testing.T) { }, nil) assert.NoError(t, err) assert.Equal(t, execConfig.MaxParallelism, int32(100)) + assert.True(t, execConfig.OverwriteCache) } func TestGetExecutionConfig_Spec(t *testing.T) { @@ -4671,14 +5091,17 @@ func TestGetExecutionConfig_Spec(t *testing.T) { Domain: workflowIdentifier.Domain, Spec: &admin.ExecutionSpec{ MaxParallelism: 100, + OverwriteCache: true, }, }, &admin.LaunchPlan{ Spec: &admin.LaunchPlanSpec{ MaxParallelism: 50, + OverwriteCache: false, // explicitly set to false for clarity }, }) assert.NoError(t, err) assert.Equal(t, int32(100), execConfig.MaxParallelism) + assert.True(t, execConfig.OverwriteCache) execConfig, err = executionManager.getExecutionConfig(context.TODO(), &admin.ExecutionCreateRequest{ Project: workflowIdentifier.Project, @@ -4687,10 +5110,12 @@ func TestGetExecutionConfig_Spec(t *testing.T) { }, &admin.LaunchPlan{ Spec: &admin.LaunchPlanSpec{ MaxParallelism: 50, + OverwriteCache: true, }, }) assert.NoError(t, err) assert.Equal(t, int32(50), execConfig.MaxParallelism) + assert.True(t, execConfig.OverwriteCache) resourceManager = managerMocks.MockResourceManager{} resourceManager.GetResourceFunc = func(ctx context.Context, @@ -4702,6 +5127,8 @@ func TestGetExecutionConfig_Spec(t *testing.T) { config: applicationConfig, } + executionManager.config.ApplicationConfiguration().GetTopLevelConfig().OverwriteCache = true + execConfig, err = executionManager.getExecutionConfig(context.TODO(), &admin.ExecutionCreateRequest{ Project: workflowIdentifier.Project, Domain: workflowIdentifier.Domain, @@ -4711,6 +5138,7 @@ func TestGetExecutionConfig_Spec(t *testing.T) { }) assert.NoError(t, err) assert.Equal(t, execConfig.MaxParallelism, int32(25)) + assert.True(t, execConfig.OverwriteCache) } func TestGetClusterAssignment(t *testing.T) { diff --git a/pkg/manager/impl/shared/iface.go b/pkg/manager/impl/shared/iface.go index 8cf83882b..7baae65a1 100644 --- a/pkg/manager/impl/shared/iface.go +++ b/pkg/manager/impl/shared/iface.go @@ -22,4 +22,6 @@ type WorkflowExecutionConfigInterface interface { GetLabels() *admin.Labels // GetInterruptible indicates a workflow should be flagged as interruptible for a single execution. If omitted, the workflow's default is used. GetInterruptible() *wrappers.BoolValue + // GetOverwriteCache indicates a workflow should skip all its cached results and re-compute its output, overwriting any already stored data. + GetOverwriteCache() bool } diff --git a/pkg/manager/impl/util/shared.go b/pkg/manager/impl/util/shared.go index 8bf449af8..3528b889c 100644 --- a/pkg/manager/impl/util/shared.go +++ b/pkg/manager/impl/util/shared.go @@ -256,7 +256,7 @@ func GetMatchableResource(ctx context.Context, resourceManager interfaces.Resour } // MergeIntoExecConfig into workflowExecConfig (higher priority) from spec (lower priority) and return the -// a new object with the merged changes. +// new object with the merged changes. // After settings project is done, can move this function back to execution manager. Currently shared with resource. func MergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec shared.WorkflowExecutionConfigInterface) admin.WorkflowExecutionConfig { if workflowExecConfig.GetMaxParallelism() == 0 && spec.GetMaxParallelism() > 0 { @@ -291,5 +291,10 @@ func MergeIntoExecConfig(workflowExecConfig admin.WorkflowExecutionConfig, spec if workflowExecConfig.GetInterruptible() == nil && spec.GetInterruptible() != nil { workflowExecConfig.Interruptible = spec.GetInterruptible() } + + if !workflowExecConfig.GetOverwriteCache() && spec.GetOverwriteCache() { + workflowExecConfig.OverwriteCache = spec.GetOverwriteCache() + } + return workflowExecConfig } diff --git a/pkg/runtime/interfaces/application_configuration.go b/pkg/runtime/interfaces/application_configuration.go index 08e7a18ae..ee053b7a8 100644 --- a/pkg/runtime/interfaces/application_configuration.go +++ b/pkg/runtime/interfaces/application_configuration.go @@ -78,6 +78,11 @@ type ApplicationConfig struct { Annotations map[string]string `json:"annotations,omitempty"` // Interruptible indicates whether all tasks should be run as interruptible by default (unless specified otherwise via the execution/workflow/task definition) Interruptible bool `json:"interruptible"` + // OverwriteCache indicates all workflows and tasks should skip all their cached results and re-compute their outputs, + // overwriting any already stored data. + // Note that setting this setting to `true` effectively disabled all caching in Flyte as all executions launched + // will have their OverwriteCache setting enabled. + OverwriteCache bool `json:"overwriteCache"` // Optional: security context override to apply this execution. // iam_role references the fully qualified name of Identity & Access Management role to impersonate. @@ -158,11 +163,17 @@ func (a *ApplicationConfig) GetInterruptible() *wrappers.BoolValue { } } +func (a *ApplicationConfig) GetOverwriteCache() bool { + return a.OverwriteCache +} + // GetAsWorkflowExecutionConfig returns the WorkflowExecutionConfig as extracted from this object func (a *ApplicationConfig) GetAsWorkflowExecutionConfig() admin.WorkflowExecutionConfig { - // These two should always be set, one is a number, and the other returns nil when empty. + // These values should always be set as their fallback values equals to their zero value or nil, + // providing a sensible default even if the actual value was not set. wec := admin.WorkflowExecutionConfig{ MaxParallelism: a.GetMaxParallelism(), + OverwriteCache: a.GetOverwriteCache(), Interruptible: a.GetInterruptible(), } diff --git a/pkg/workflowengine/impl/prepare_execution.go b/pkg/workflowengine/impl/prepare_execution.go index 43499922b..f2a778e27 100644 --- a/pkg/workflowengine/impl/prepare_execution.go +++ b/pkg/workflowengine/impl/prepare_execution.go @@ -57,10 +57,13 @@ func addExecutionOverrides(taskPluginOverrides []*admin.PluginOverride, } if workflowExecutionConfig != nil { executionConfig.MaxParallelism = uint32(workflowExecutionConfig.MaxParallelism) + if workflowExecutionConfig.GetInterruptible() != nil { interruptible := workflowExecutionConfig.GetInterruptible().GetValue() executionConfig.Interruptible = &interruptible } + + executionConfig.OverwriteCache = workflowExecutionConfig.GetOverwriteCache() } if taskResources != nil { var requests = v1alpha1.TaskResourceSpec{} diff --git a/pkg/workflowengine/impl/prepare_execution_test.go b/pkg/workflowengine/impl/prepare_execution_test.go index bba2b1066..38e155636 100644 --- a/pkg/workflowengine/impl/prepare_execution_test.go +++ b/pkg/workflowengine/impl/prepare_execution_test.go @@ -158,6 +158,14 @@ func TestAddExecutionOverrides(t *testing.T) { assert.NotNil(t, workflow.ExecutionConfig.Interruptible) assert.True(t, *workflow.ExecutionConfig.Interruptible) }) + t.Run("skip cache", func(t *testing.T) { + workflowExecutionConfig := &admin.WorkflowExecutionConfig{ + OverwriteCache: true, + } + workflow := &v1alpha1.FlyteWorkflow{} + addExecutionOverrides(nil, workflowExecutionConfig, nil, nil, workflow) + assert.True(t, workflow.ExecutionConfig.OverwriteCache) + }) } func TestPrepareFlyteWorkflow(t *testing.T) {