diff --git a/enterprise/server/remote_execution/container/container.go b/enterprise/server/remote_execution/container/container.go index 85e2d9c2faa..e24b8a31d65 100644 --- a/enterprise/server/remote_execution/container/container.go +++ b/enterprise/server/remote_execution/container/container.go @@ -64,6 +64,7 @@ var ( ErrRemoved = status.UnavailableError("container has been removed") recordCPUTimelines = flag.Bool("executor.record_cpu_timelines", false, "Capture CPU timeseries data in UsageStats for each task.") + imagePullTimeout = flag.Duration("executor.image_pull_timeout", 5*time.Minute, "How long to wait for the container image to be pulled before returning an Unavailable (retryable) error for an action execution attempt. Applies to all isolation types (docker, firecracker, etc.)") debugUseLocalImagesOnly = flag.Bool("debug_use_local_images_only", false, "Do not pull OCI images and only used locally cached images. This can be set to test local image builds during development without needing to push to a container registry. Not intended for production use.") DebugEnableAnonymousRecycling = flag.Bool("debug_enable_anonymous_runner_recycling", false, "Whether to enable runner recycling for unauthenticated requests. For debugging purposes only - do not use in production.") @@ -513,6 +514,25 @@ func PullImageIfNecessary(ctx context.Context, env environment.Env, ctr CommandC ctx, span := tracing.StartSpan(ctx) defer span.End() + + if *imagePullTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, *imagePullTimeout) + defer cancel() + } + + if err := pullImageIfNecessary(ctx, env, ctr, creds, imageRef); err != nil { + // make sure we always return Unavailable if the context deadline + // was exceeded + if err == context.DeadlineExceeded || ctx.Err() != nil { + return status.UnavailableErrorf("%s", status.Message(err)) + } + return err + } + return nil +} + +func pullImageIfNecessary(ctx context.Context, env environment.Env, ctr CommandContainer, creds oci.Credentials, imageRef string) error { cacheAuth := env.GetImageCacheAuthenticator() if cacheAuth == nil || env.GetAuthenticator() == nil { // If we don't have an authenticator available, fall back to diff --git a/enterprise/server/remote_execution/runner/runner_test.go b/enterprise/server/remote_execution/runner/runner_test.go index c3691308427..7e4b288c374 100644 --- a/enterprise/server/remote_execution/runner/runner_test.go +++ b/enterprise/server/remote_execution/runner/runner_test.go @@ -72,6 +72,9 @@ type fakeContainer struct { CreateError error Removed chan struct{} Result *interfaces.CommandResult + Isolation string // Fake isolation type name + ImageCached bool // Return value for IsImageCached + BlockPull bool // PullImage blocks forever if true. } func NewFakeContainer() *fakeContainer { @@ -81,11 +84,26 @@ func NewFakeContainer() *fakeContainer { } } +func (c *fakeContainer) IsolationType() string { + if c.Isolation == "" { + return "bare" + } + return c.Isolation +} + func (c *fakeContainer) Run(ctx context.Context, cmd *repb.Command, workdir string, creds oci.Credentials) *interfaces.CommandResult { return c.Result } +func (c *fakeContainer) IsImageCached(ctx context.Context) (bool, error) { + return c.ImageCached, nil +} + func (c *fakeContainer) PullImage(ctx context.Context, creds oci.Credentials) error { + if c.BlockPull { + <-ctx.Done() + return ctx.Err() + } return nil } @@ -867,3 +885,34 @@ func TestDoNotRecycleSpecialFile(t *testing.T) { }) } } + +func TestImagePullTimeout(t *testing.T) { + // Enable OCI isolation so we can pull images. + flags.Set(t, "executor.enable_oci", true) + // Time out image pulls immediately + flags.Set(t, "executor.image_pull_timeout", 1*time.Nanosecond) + + env := newTestEnv(t) + cfg := noLimitsCfg() + cfg.ContainerProvider = providerFunc(func(ctx context.Context, args *container.Init) (container.CommandContainer, error) { + ctr := NewFakeContainer() + ctr.BlockPull = true + ctr.Isolation = "oci" + return ctr, nil + }) + pool := newRunnerPool(t, env, cfg) + ctx := withAuthenticatedUser(t, context.Background(), env, "US1") + task := newTask() + plat := task.ExecutionTask.Command.Platform + plat.Properties = append(plat.Properties, []*repb.Platform_Property{ + {Name: "container-image", Value: "docker://busybox"}, + {Name: "workload-isolation-type", Value: "oci"}, + }...) + r, err := pool.Get(ctx, task) + require.NoError(t, err) + + err = r.PrepareForTask(ctx) + require.Error(t, err) + assert.True(t, status.IsUnavailableError(err), "expected Unavailable, got %T", err) + assert.Contains(t, err.Error(), "deadline exceeded") +}