From 1382a6c557fcf9c5dfe9017c05b465b04bd63b2d Mon Sep 17 00:00:00 2001 From: Jakob Gray <20209054+JakobGray@users.noreply.github.com> Date: Mon, 21 Oct 2024 11:41:24 -0400 Subject: [PATCH] Require provider, do not default to AWS, and check for provider-specific flags (#678) --- cmd/ocm/create/cluster/cmd.go | 30 ++++++++++++++++++--- pkg/arguments/arguments.go | 50 +++++++++++++++++++++++++++++++++-- 2 files changed, 74 insertions(+), 6 deletions(-) diff --git a/cmd/ocm/create/cluster/cmd.go b/cmd/ocm/create/cluster/cmd.go index dccf4afe..10fed637 100644 --- a/cmd/ocm/create/cluster/cmd.go +++ b/cmd/ocm/create/cluster/cmd.go @@ -615,9 +615,6 @@ func preRun(cmd *cobra.Command, argv []string) error { return err } - // Only offer the 2 providers known to support OSD now; - // but don't validate if set, to not block `ocm` CLI from creating clusters on future providers. - providers, _ := osdProviderOptions(connection) // If marketplace-gcp subscription type is used, provider can only be GCP gcpBillingModel, _ := billing.GetBillingModel(connection, billing.MarketplaceGcpSubscriptionType) gcpSubscriptionTypeTemplate := subscriptionTypeOption(gcpBillingModel.ID(), gcpBillingModel.Description()) @@ -625,6 +622,10 @@ func preRun(cmd *cobra.Command, argv []string) error { parseSubscriptionType(args.subscriptionType) == parseSubscriptionType(gcpSubscriptionTypeTemplate.Value) if isGcpMarketplace { + if args.provider != c.ProviderGCP && args.provider != "" { + return fmt.Errorf("Provider must be set to %s when using %s subscription type", + c.ProviderGCP, billing.MarketplaceGcpSubscriptionType) + } fmt.Println("setting provider to", c.ProviderGCP) args.provider = c.ProviderGCP fmt.Println("setting ccs to 'true'") @@ -643,12 +644,17 @@ func preRun(cmd *cobra.Command, argv []string) error { return fmt.Errorf(gcpTermsAgreementNonInteractiveError) } } else { - err = arguments.PromptOneOf(fs, "provider", providers) + err = promptProvider(fs, connection) if err != nil { return err } } + err = arguments.CheckIgnoredProviderFlags(fs, args.provider) + if err != nil { + return err + } + if wasClusterWideProxyReceived() { args.ccs.Enabled = true args.existingVPC.Enabled = true @@ -912,6 +918,22 @@ func promptName(argv []string) error { return fmt.Errorf("A cluster name must be specified") } +// promptProvider reads or prompts for the provider +func promptProvider(fs *pflag.FlagSet, connection *sdk.Connection) error { + // Only offer the 2 providers known to support OSD now; + // but don't validate if set, to not block `ocm` CLI from creating clusters on future providers. + providers, _ := osdProviderOptions(connection) + + err := arguments.PromptOneOf(fs, "provider", providers) + if err != nil { + return err + } + if args.provider == "" { + return fmt.Errorf("A provider must be specified") + } + return nil +} + func promptClusterWideProxy() error { var err error if args.existingVPC.Enabled && !wasClusterWideProxyReceived() && args.interactive { diff --git a/pkg/arguments/arguments.go b/pkg/arguments/arguments.go index 7b12e1f6..d8f2c312 100644 --- a/pkg/arguments/arguments.go +++ b/pkg/arguments/arguments.go @@ -157,6 +157,52 @@ func CheckIgnoredCCSFlags(ccs cluster.CCS) error { return nil } +// CheckIgnoredProviderFlags errors if provider-specific flags were used without the corresponding provider. +func CheckIgnoredProviderFlags(fs *pflag.FlagSet, provider string) error { + gcpExclusiveFlags := []string{ + "marketplace-gcp-terms", + "psc-subnet", + "secure-boot-for-shielded-vms", + "service-account-file", + "vpc-name", + "vpc-project-id", + "wif-config", + } + awsExclusiveFlags := []string{ + "aws-account-id", + "aws-access-key-id", + "aws-secret-access-key", + "additional-compute-security-group-ids", + "additional-infra-security-group-ids", + "additional-control-plane-security-group-ids", + "additional-trust-bundle-file", + "subnet-ids", + } + + bad := []string{} + if provider != cluster.ProviderGCP { + for _, flag := range gcpExclusiveFlags { + if fs.Changed(flag) { + bad = append(bad, flag) + } + } + } + if provider != cluster.ProviderAWS { + for _, flag := range awsExclusiveFlags { + if fs.Changed(flag) { + bad = append(bad, flag) + } + } + } + if len(bad) == 1 { + return fmt.Errorf("%s flag is meaningless for chosen provider", bad[0]) + } else if len(bad) > 1 { + return fmt.Errorf("%s flags are meaningless for chosen provider", + strings.Join(bad, ", ")) + } + return nil +} + const ( additionalComputeSecurityGroupIdsFlag = "additional-compute-security-group-ids" additionalInfraSecurityGroupIdsFlag = "additional-infra-security-group-ids" @@ -344,8 +390,8 @@ func AddProviderFlag(fs *pflag.FlagSet, value *string) { fs.StringVar( value, "provider", - "aws", - "The cloud provider to create the cluster on", + "", + "The cloud provider to create the cluster on. Supported options are [aws gcp]", ) SetQuestion(fs, "provider", "Cloud provider:") }