diff --git a/internal/db/start/start.go b/internal/db/start/start.go index 95cfa85ec..8adb00745 100644 --- a/internal/db/start/start.go +++ b/internal/db/start/start.go @@ -45,7 +45,9 @@ func Run(ctx context.Context, fsys afero.Fs) error { utils.Config.Analytics.Enabled = false err := StartDatabase(ctx, fsys, os.Stderr) if err != nil { - utils.DockerRemoveAll(context.Background()) + if err := utils.DockerRemoveAll(context.Background(), io.Discard); err != nil { + fmt.Fprintln(os.Stderr, err) + } } return err } @@ -126,8 +128,8 @@ func StartDatabase(ctx context.Context, fsys afero.Fs, w io.Writer, options ...f } // Creating volume will not override existing volume, so we must inspect explicitly _, err := utils.Docker.VolumeInspect(ctx, utils.DbId) - noBackupVolume := client.IsErrNotFound(err) - if noBackupVolume { + utils.NoBackupVolume = client.IsErrNotFound(err) + if utils.NoBackupVolume { fmt.Fprintln(w, "Starting database...") } else { fmt.Fprintln(w, "Starting database from backup...") @@ -139,7 +141,7 @@ func StartDatabase(ctx context.Context, fsys afero.Fs, w io.Writer, options ...f return errors.New(ErrDatabase) } // Initialize if we are on PG14 and there's no existing db volume - if noBackupVolume { + if utils.NoBackupVolume { if err := setupDatabase(ctx, fsys, w, options...); err != nil { return err } diff --git a/internal/db/start/start_test.go b/internal/db/start/start_test.go index 81b92d634..be7c8d10b 100644 --- a/internal/db/start/start_test.go +++ b/internal/db/start/start_test.go @@ -50,13 +50,7 @@ func TestInitBranch(t *testing.T) { } func TestStartDatabase(t *testing.T) { - teardown := func() { - utils.Containers = []string{} - utils.Volumes = []string{} - } - t.Run("initialize main branch", func(t *testing.T) { - defer teardown() utils.Config.Db.MajorVersion = 15 utils.Config.Db.Image = utils.Pg15Image utils.DbId = "supabase_db_test" @@ -108,7 +102,6 @@ func TestStartDatabase(t *testing.T) { }) t.Run("recover from backup volume", func(t *testing.T) { - defer teardown() utils.Config.Db.MajorVersion = 14 utils.Config.Db.Image = utils.Pg15Image utils.DbId = "supabase_db_test" @@ -145,7 +138,6 @@ func TestStartDatabase(t *testing.T) { }) t.Run("throws error on start failure", func(t *testing.T) { - defer teardown() utils.Config.Db.MajorVersion = 15 utils.Config.Db.Image = utils.Pg15Image utils.DbId = "supabase_db_test" @@ -241,9 +233,7 @@ func TestStartCommand(t *testing.T) { Get("/v" + utils.Docker.ClientVersion() + "/images/" + utils.GetRegistryImageUrl(utils.Pg15Image) + "/json"). ReplyError(errors.New("network error")) // Cleanup resources - gock.New(utils.Docker.DaemonHost()). - Post("/v" + utils.Docker.ClientVersion() + "/networks/prune"). - Reply(http.StatusOK) + apitest.MockDockerStop(utils.Docker) // Run test err := Run(context.Background(), fsys) // Check error diff --git a/internal/start/start.go b/internal/start/start.go index f43e2e5b6..98c958cab 100644 --- a/internal/start/start.go +++ b/internal/start/start.go @@ -5,6 +5,7 @@ import ( "context" _ "embed" "fmt" + "io" "os" "path" "path/filepath" @@ -71,7 +72,9 @@ func Run(ctx context.Context, fsys afero.Fs, excludedContainers []string, ignore if ignoreHealthCheck && errors.Is(err, reset.ErrUnhealthy) { fmt.Fprintln(os.Stderr, err) } else { - utils.DockerRemoveAll(context.Background()) + if err := utils.DockerRemoveAll(context.Background(), io.Discard); err != nil { + fmt.Fprintln(os.Stderr, err) + } return err } } diff --git a/internal/stop/stop.go b/internal/stop/stop.go index 1bd6974c9..3e6f8d6ac 100644 --- a/internal/stop/stop.go +++ b/internal/stop/stop.go @@ -5,12 +5,7 @@ import ( _ "embed" "fmt" "io" - "os" - "github.com/docker/docker/api/types" - "github.com/docker/docker/api/types/container" - "github.com/docker/docker/errdefs" - "github.com/go-errors/errors" "github.com/spf13/afero" "github.com/supabase/cli/internal/utils" ) @@ -33,67 +28,14 @@ func Run(ctx context.Context, backup bool, projectId string, fsys afero.Fs) erro } fmt.Println("Stopped " + utils.Aqua("supabase") + " local development setup.") + if backup { + listVolume := fmt.Sprintf("docker volume ls --filter label=%s=%s", utils.CliProjectLabel, utils.Config.ProjectId) + utils.CmdSuggestion = "Local data are backed up to docker volume. You may list them with " + utils.Aqua(listVolume) + } return nil } func stop(ctx context.Context, backup bool, w io.Writer) error { - args := utils.CliProjectFilter() - containers, err := utils.Docker.ContainerList(ctx, types.ContainerListOptions{ - All: true, - Filters: args, - }) - if err != nil { - return errors.Errorf("failed to list containers: %w", err) - } - // Gracefully shutdown containers - var ids []string - for _, c := range containers { - if c.State == "running" { - ids = append(ids, c.ID) - } - } - fmt.Fprintln(w, "Stopping containers...") - result := utils.WaitAll(ids, func(id string) error { - if err := utils.Docker.ContainerStop(ctx, id, container.StopOptions{}); err != nil { - return errors.Errorf("failed to stop container: %w", err) - } - return nil - }) - if err := errors.Join(result...); err != nil { - return err - } - if _, err := utils.Docker.ContainersPrune(ctx, args); err != nil { - return errors.Errorf("failed to prune containers: %w", err) - } - // Remove named volumes - if backup { - fmt.Fprintln(os.Stderr, "Postgres database saved to volume:", utils.DbId) - fmt.Fprintln(os.Stderr, "Postgres config saved to volume:", utils.ConfigId) - fmt.Fprintln(os.Stderr, "Storage directory saved to volume:", utils.StorageId) - fmt.Fprintln(os.Stderr, "Functions cache saved to volume:", utils.EdgeRuntimeId) - fmt.Fprintln(os.Stderr, "Inbucket emails saved to volume:", utils.InbucketId) - } else { - // TODO: label named volumes to use VolumesPrune for branch support - volumes := []string{ - utils.ConfigId, - utils.DbId, - utils.StorageId, - utils.EdgeRuntimeId, - utils.InbucketId, - } - result = utils.WaitAll(volumes, func(name string) error { - if err := utils.Docker.VolumeRemove(ctx, name, true); err != nil && !errdefs.IsNotFound(err) { - return errors.Errorf("Failed to remove volume %s: %w", name, err) - } - return nil - }) - if err := errors.Join(result...); err != nil { - return err - } - } - // Remove networks. - if _, err = utils.Docker.NetworksPrune(ctx, args); err != nil { - return errors.Errorf("failed to prune networks: %w", err) - } - return nil + utils.NoBackupVolume = !backup + return utils.DockerRemoveAll(ctx, w) } diff --git a/internal/stop/stop_test.go b/internal/stop/stop_test.go index 47b645e30..b935220da 100644 --- a/internal/stop/stop_test.go +++ b/internal/stop/stop_test.go @@ -108,33 +108,7 @@ func TestStopServices(t *testing.T) { // Setup mock docker require.NoError(t, apitest.MockDocker(utils.Docker)) defer gock.OffAll() - gock.New(utils.Docker.DaemonHost()). - Get("/v" + utils.Docker.ClientVersion() + "/containers/json"). - Reply(http.StatusOK). - JSON([]types.Container{}) - gock.New(utils.Docker.DaemonHost()). - Post("/v" + utils.Docker.ClientVersion() + "/containers/prune"). - Reply(http.StatusOK). - JSON(types.ContainersPruneReport{}) - gock.New(utils.Docker.DaemonHost()). - Delete("/v" + utils.Docker.ClientVersion() + "/volumes/" + utils.ConfigId). - Reply(http.StatusOK) - gock.New(utils.Docker.DaemonHost()). - Delete("/v" + utils.Docker.ClientVersion() + "/volumes/" + utils.DbId). - Reply(http.StatusOK) - gock.New(utils.Docker.DaemonHost()). - Delete("/v" + utils.Docker.ClientVersion() + "/volumes/" + utils.StorageId). - Reply(http.StatusNotFound) - gock.New(utils.Docker.DaemonHost()). - Delete("/v" + utils.Docker.ClientVersion() + "/volumes/" + utils.EdgeRuntimeId). - Reply(http.StatusNotFound) - gock.New(utils.Docker.DaemonHost()). - Delete("/v" + utils.Docker.ClientVersion() + "/volumes/" + utils.InbucketId). - Reply(http.StatusNotFound) - gock.New(utils.Docker.DaemonHost()). - Post("/v" + utils.Docker.ClientVersion() + "/networks/prune"). - Reply(http.StatusOK). - JSON(types.NetworksPruneReport{}) + apitest.MockDockerStop(utils.Docker) // Run test err := stop(context.Background(), false, io.Discard) // Check error diff --git a/internal/testing/apitest/helper.go b/internal/testing/apitest/helper.go index 5c6d79683..988393fec 100644 --- a/internal/testing/apitest/helper.go +++ b/internal/testing/apitest/helper.go @@ -8,6 +8,7 @@ import ( "github.com/docker/docker/api" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/volume" "github.com/docker/docker/client" "github.com/docker/docker/pkg/stdcopy" "gopkg.in/h2non/gock.v1" @@ -39,6 +40,11 @@ func MockDockerStart(docker *client.Client, image, containerID string) { Post("/v" + docker.ClientVersion() + "/networks/create"). Reply(http.StatusCreated). JSON(types.NetworkCreateResponse{}) + gock.New(docker.DaemonHost()). + Post("/v" + docker.ClientVersion() + "/volumes/create"). + Persist(). + Reply(http.StatusCreated). + JSON(volume.Volume{}) gock.New(docker.DaemonHost()). Post("/v" + docker.ClientVersion() + "/containers/create"). Reply(http.StatusOK). @@ -48,31 +54,31 @@ func MockDockerStart(docker *client.Client, image, containerID string) { Reply(http.StatusAccepted) } -// Ref: internal/utils/docker.go::DockerRunOnce -func MockDockerLogs(docker *client.Client, containerID, stdout string) error { - var body bytes.Buffer - writer := stdcopy.NewStdWriter(&body, stdcopy.Stdout) - _, err := writer.Write([]byte(stdout)) +// Ref: internal/utils/docker.go::DockerRemoveAll +func MockDockerStop(docker *client.Client) { gock.New(docker.DaemonHost()). - Get("/v"+docker.ClientVersion()+"/containers/"+containerID+"/logs"). + Get("/v" + docker.ClientVersion() + "/containers/json"). Reply(http.StatusOK). - SetHeader("Content-Type", "application/vnd.docker.raw-stream"). - Body(&body) + JSON([]types.Container{}) gock.New(docker.DaemonHost()). - Get("/v" + docker.ClientVersion() + "/containers/" + containerID + "/json"). + Post("/v" + docker.ClientVersion() + "/containers/prune"). Reply(http.StatusOK). - JSON(types.ContainerJSONBase{State: &types.ContainerState{ExitCode: 0}}) + JSON(types.ContainersPruneReport{}) gock.New(docker.DaemonHost()). - Delete("/v" + docker.ClientVersion() + "/containers/" + containerID). - Reply(http.StatusOK) - return err + Post("/v" + docker.ClientVersion() + "/volumes/prune"). + Reply(http.StatusOK). + JSON(types.VolumesPruneReport{}) + gock.New(docker.DaemonHost()). + Post("/v" + docker.ClientVersion() + "/networks/prune"). + Reply(http.StatusOK). + JSON(types.NetworksPruneReport{}) } // Ref: internal/utils/docker.go::DockerRunOnce -func MockDockerLogsExitCode(docker *client.Client, containerID string, exitCode int) error { +func setupDockerLogs(docker *client.Client, containerID, stdout string, exitCode int) error { var body bytes.Buffer writer := stdcopy.NewStdWriter(&body, stdcopy.Stdout) - _, err := writer.Write([]byte("")) + _, err := writer.Write([]byte(stdout)) gock.New(docker.DaemonHost()). Get("/v"+docker.ClientVersion()+"/containers/"+containerID+"/logs"). Reply(http.StatusOK). @@ -81,13 +87,23 @@ func MockDockerLogsExitCode(docker *client.Client, containerID string, exitCode gock.New(docker.DaemonHost()). Get("/v" + docker.ClientVersion() + "/containers/" + containerID + "/json"). Reply(http.StatusOK). - JSON(types.ContainerJSONBase{State: &types.ContainerState{ExitCode: exitCode}}) + JSON(types.ContainerJSONBase{State: &types.ContainerState{ + ExitCode: exitCode, + }}) gock.New(docker.DaemonHost()). Delete("/v" + docker.ClientVersion() + "/containers/" + containerID). Reply(http.StatusOK) return err } +func MockDockerLogs(docker *client.Client, containerID, stdout string) error { + return setupDockerLogs(docker, containerID, stdout, 0) +} + +func MockDockerLogsExitCode(docker *client.Client, containerID string, exitCode int) error { + return setupDockerLogs(docker, containerID, "", exitCode) +} + func ListUnmatchedRequests() []string { result := make([]string, len(gock.GetUnmatchedRequests())) for i, r := range gock.GetUnmatchedRequests() { diff --git a/internal/utils/docker.go b/internal/utils/docker.go index 986c25994..0bf7c7bd3 100644 --- a/internal/utils/docker.go +++ b/internal/utils/docker.go @@ -1,7 +1,6 @@ package utils import ( - "archive/tar" "bytes" "context" "encoding/base64" @@ -26,6 +25,7 @@ import ( "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/mount" "github.com/docker/docker/api/types/network" + "github.com/docker/docker/api/types/volume" "github.com/docker/docker/client" "github.com/docker/docker/errdefs" "github.com/docker/docker/pkg/jsonmessage" @@ -60,7 +60,7 @@ func AssertDockerIsRunning(ctx context.Context) error { } const ( - cliProjectLabel = "com.supabase.cli.project" + CliProjectLabel = "com.supabase.cli.project" composeProjectLabel = "com.docker.compose.projecta" ) @@ -71,7 +71,7 @@ func DockerNetworkCreateIfNotExists(ctx context.Context, networkId string) error types.NetworkCreate{ CheckDuplicate: true, Labels: map[string]string{ - cliProjectLabel: Config.ProjectId, + CliProjectLabel: Config.ProjectId, composeProjectLabel: Config.ProjectId, }, }, @@ -86,21 +86,12 @@ func DockerNetworkCreateIfNotExists(ctx context.Context, networkId string) error return err } -// Used by unit tests -// NOTE: There's a risk of data race with reads & writes from `DockerRun` and -// reads from `DockerRemoveAll`, but since they're expected to be run on the -// same thread, this is fine. -var ( - Containers []string - Volumes []string -) - -func WaitAll(containers []string, exec func(container string) error) []error { +func WaitAll[T any](containers []T, exec func(container T) error) []error { var wg sync.WaitGroup result := make([]error, len(containers)) for i, container := range containers { wg.Add(1) - go func(i int, container string) { + go func(i int, container T) { defer wg.Done() result[i] = exec(container) }(i, container) @@ -109,57 +100,57 @@ func WaitAll(containers []string, exec func(container string) error) []error { return result } -func DockerRemoveAll(ctx context.Context) { - _ = WaitAll(Containers, func(container string) error { - return Docker.ContainerRemove(ctx, container, types.ContainerRemoveOptions{ - RemoveVolumes: true, - Force: true, - }) - }) - _ = WaitAll(Volumes, func(name string) error { - return Docker.VolumeRemove(ctx, name, true) - }) - _, _ = Docker.NetworksPrune(ctx, CliProjectFilter()) -} +// TODO: encapsulate this state in a class +var NoBackupVolume = false -func CliProjectFilter() filters.Args { - return filters.NewArgs( - filters.Arg("label", cliProjectLabel+"="+Config.ProjectId), - ) -} - -func DockerAddFile(ctx context.Context, container string, fileName string, content []byte) error { - var buf bytes.Buffer - tw := tar.NewWriter(&buf) - err := tw.WriteHeader(&tar.Header{ - Name: fileName, - Mode: 0777, - Size: int64(len(content)), +func DockerRemoveAll(ctx context.Context, w io.Writer) error { + args := CliProjectFilter() + containers, err := Docker.ContainerList(ctx, types.ContainerListOptions{ + All: true, + Filters: args, }) - if err != nil { - return errors.Errorf("failed to copy file: %w", err) + return errors.Errorf("failed to list containers: %w", err) } - - _, err = tw.Write(content) - - if err != nil { - return errors.Errorf("failed to copy file: %w", err) + // Gracefully shutdown containers + var ids []string + for _, c := range containers { + if c.State == "running" { + ids = append(ids, c.ID) + } } - - err = tw.Close() - - if err != nil { - return errors.Errorf("failed to copy file: %w", err) + fmt.Fprintln(w, "Stopping containers...") + result := WaitAll(ids, func(id string) error { + if err := Docker.ContainerStop(ctx, id, container.StopOptions{}); err != nil { + return errors.Errorf("failed to stop container: %w", err) + } + return nil + }) + if err := errors.Join(result...); err != nil { + return err } - - err = Docker.CopyToContainer(ctx, container, "/tmp", &buf, types.CopyToContainerOptions{}) - if err != nil { - return errors.Errorf("failed to copy file: %w", err) + if _, err := Docker.ContainersPrune(ctx, args); err != nil { + return errors.Errorf("failed to prune containers: %w", err) + } + // Remove named volumes + if NoBackupVolume { + if _, err := Docker.VolumesPrune(ctx, args); err != nil { + return errors.Errorf("failed to prune volumes: %w", err) + } + } + // Remove networks. + if _, err = Docker.NetworksPrune(ctx, args); err != nil { + return errors.Errorf("failed to prune networks: %w", err) } return nil } +func CliProjectFilter() filters.Args { + return filters.NewArgs( + filters.Arg("label", CliProjectLabel+"="+Config.ProjectId), + ) +} + var ( // Only supports one registry per command invocation registryAuth string @@ -264,7 +255,7 @@ func DockerStart(ctx context.Context, config container.Config, hostConfig contai if config.Labels == nil { config.Labels = map[string]string{} } - config.Labels[cliProjectLabel] = Config.ProjectId + config.Labels[CliProjectLabel] = Config.ProjectId config.Labels[composeProjectLabel] = Config.ProjectId if len(hostConfig.NetworkMode) == 0 { hostConfig.NetworkMode = container.NetworkMode(NetId) @@ -275,38 +266,37 @@ func DockerStart(ctx context.Context, config container.Config, hostConfig contai return "", err } } + var binds, sources []string + for _, bind := range hostConfig.Binds { + spec, err := loader.ParseVolume(bind) + if err != nil { + return "", errors.Errorf("failed to parse docker volume: %w", err) + } + if spec.Type != string(mount.TypeVolume) { + binds = append(binds, bind) + } else if len(spec.Source) > 0 { + sources = append(sources, spec.Source) + } + } // Skip named volume for BitBucket pipeline - bitbucket := os.Getenv("BITBUCKET_CLONE_DIR") - if len(bitbucket) > 0 { - var binds []string - for _, bind := range hostConfig.Binds { - spec, err := loader.ParseVolume(bind) - if err != nil { - return "", errors.Errorf("failed to parse docker volume: %w", err) - } - if spec.Type != string(mount.TypeVolume) { - binds = append(binds, bind) + if os.Getenv("BITBUCKET_CLONE_DIR") != "" { + hostConfig.Binds = binds + } else { + // Create named volumes with labels + for _, name := range sources { + if _, err := Docker.VolumeCreate(ctx, volume.CreateOptions{ + Name: name, + Labels: config.Labels, + }); err != nil { + return "", errors.Errorf("failed to create volume: %w", err) } } - hostConfig.Binds = binds } // Create container from image resp, err := Docker.ContainerCreate(ctx, &config, &hostConfig, &networkingConfig, nil, containerName) if err != nil { return "", errors.Errorf("failed to create docker container: %w", err) } - // Track container id for cleanup - Containers = append(Containers, resp.ID) - for _, bind := range hostConfig.Binds { - spec, err := loader.ParseVolume(bind) - if err != nil { - return "", errors.Errorf("failed to parse docker volume: %w", err) - } - // Track named volumes for cleanup - if len(spec.Source) > 0 && spec.Type == string(mount.TypeVolume) { - Volumes = append(Volumes, spec.Source) - } - } // Run container in background err = Docker.ContainerStart(ctx, resp.ID, types.ContainerStartOptions{}) if err != nil { @@ -451,7 +441,7 @@ func suggestDockerStop(ctx context.Context, hostPort string) string { for _, c := range containers { for _, p := range c.Ports { if fmt.Sprintf("%s:%d", p.IP, p.PublicPort) == hostPort { - if project, ok := c.Labels[cliProjectLabel]; ok { + if project, ok := c.Labels[CliProjectLabel]; ok { return "\nTry stopping the running project with " + Aqua("supabase stop --project-id "+project) } else { name := c.ID