diff --git a/cmd/terraform-demux/main.go b/cmd/terraform-demux/main.go index 39b9816..5992295 100644 --- a/cmd/terraform-demux/main.go +++ b/cmd/terraform-demux/main.go @@ -1,6 +1,7 @@ package main import ( + "errors" "io" "log" "os" @@ -26,6 +27,11 @@ func main() { log.Printf("terraform-demux version %s, using arch '%s'", version, arch) + if err := checkStateCommand(os.Args); err != nil { + log.SetOutput(os.Stderr) + log.Fatal("error: ", err) + } + exitCode, err := wrapper.RunTerraform(os.Args[1:], arch) if err != nil { @@ -36,3 +42,19 @@ func main() { os.Exit(exitCode) } + +func checkStateCommand(args []string) error { + if checkArgsExists(args, "state") && !checkArgsExists(args, "--force") { + return errors.New("--force flag is required for the 'state' command") + } + return nil +} + +func checkArgsExists(args []string, cmd string) bool { + for _, arg := range args { + if arg == cmd { + return true + } + } + return false +} diff --git a/cmd/terraform-demux/main_test.go b/cmd/terraform-demux/main_test.go new file mode 100644 index 0000000..01394b6 --- /dev/null +++ b/cmd/terraform-demux/main_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "testing" +) + +func TestCheckStateCommand(t *testing.T) { + t.Run("Valid state command with --force flag after state command", func(t *testing.T) { + args := []string{"terraform", "state", "--force", "list"} + err := checkStateCommand(args) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + }) + + t.Run("Valid state command with --force flag before state command", func(t *testing.T) { + args := []string{"terraform", "--force", "state", "pull"} + err := checkStateCommand(args) + if err != nil { + t.Errorf("Expected no error, got: %v", err) + } + }) + + t.Run("Invalid state command without --force flag", func(t *testing.T) { + args := []string{"terraform", "state", "list"} + err := checkStateCommand(args) + expectedError := "--force flag is required for the 'state' command" + if err == nil || err.Error() != expectedError { + t.Errorf("Expected error: %s, got: %v", expectedError, err) + } + }) +}