Skip to content

Commit

Permalink
Require provider, do not default to AWS, and check for provider-speci…
Browse files Browse the repository at this point in the history
…fic flags
  • Loading branch information
JakobGray committed Oct 17, 2024
1 parent 8b70707 commit cb6f76c
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 6 deletions.
30 changes: 26 additions & 4 deletions cmd/ocm/create/cluster/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -615,16 +615,17 @@ func preRun(cmd *cobra.Command, argv []string) error {
return err
}

// Only offer the 2 providers known to support OSD now;
// but don't validate if set, to not block `ocm` CLI from creating clusters on future providers.
providers, _ := osdProviderOptions(connection)
// If marketplace-gcp subscription type is used, provider can only be GCP
gcpBillingModel, _ := billing.GetBillingModel(connection, billing.MarketplaceGcpSubscriptionType)
gcpSubscriptionTypeTemplate := subscriptionTypeOption(gcpBillingModel.ID(), gcpBillingModel.Description())
isGcpMarketplace :=
parseSubscriptionType(args.subscriptionType) == parseSubscriptionType(gcpSubscriptionTypeTemplate.Value)

if isGcpMarketplace {
if args.provider != c.ProviderGCP && args.provider != "" {
return fmt.Errorf("Provider must be set to %s when using %s subscription type",
c.ProviderGCP, billing.MarketplaceGcpSubscriptionType)
}
fmt.Println("setting provider to", c.ProviderGCP)
args.provider = c.ProviderGCP
fmt.Println("setting ccs to 'true'")
Expand All @@ -643,12 +644,17 @@ func preRun(cmd *cobra.Command, argv []string) error {
return fmt.Errorf(gcpTermsAgreementNonInteractiveError)
}
} else {
err = arguments.PromptOneOf(fs, "provider", providers)
err = promptProvider(fs, connection)
if err != nil {
return err
}
}

err = arguments.CheckIgnoredProviderFlags(fs, args.provider)
if err != nil {
return err
}

if wasClusterWideProxyReceived() {
args.ccs.Enabled = true
args.existingVPC.Enabled = true
Expand Down Expand Up @@ -912,6 +918,22 @@ func promptName(argv []string) error {
return fmt.Errorf("A cluster name must be specified")
}

// promptProvider reads or prompts for the provider
func promptProvider(fs *pflag.FlagSet, connection *sdk.Connection) error {
// Only offer the 2 providers known to support OSD now;
// but don't validate if set, to not block `ocm` CLI from creating clusters on future providers.
providers, _ := osdProviderOptions(connection)

err := arguments.PromptOneOf(fs, "provider", providers)
if err != nil {
return err
}
if args.provider == "" {
return fmt.Errorf("A provider must be specified")
}
return nil
}

func promptClusterWideProxy() error {
var err error
if args.existingVPC.Enabled && !wasClusterWideProxyReceived() && args.interactive {
Expand Down
50 changes: 48 additions & 2 deletions pkg/arguments/arguments.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,52 @@ func CheckIgnoredCCSFlags(ccs cluster.CCS) error {
return nil
}

// CheckIgnoredProviderFlags errors if provider-specific flags were used without the corresponding provider.
func CheckIgnoredProviderFlags(fs *pflag.FlagSet, provider string) error {
gcpExclusiveFlags := []string{
"marketplace-gcp-terms",
"psc-subnet",
"secure-boot-for-shielded-vms",
"service-account-file",
"vpc-name",
"vpc-project-id",
"wif-config",
}
awsExclusiveFlags := []string{
"aws-account-id",
"aws-access-key-id",
"aws-secret-access-key",
"additional-compute-security-group-ids",
"additional-infra-security-group-ids",
"additional-control-plane-security-group-ids",
"additional-trust-bundle-file",
"subnet-ids",
}

bad := []string{}
if provider != cluster.ProviderGCP {
for _, flag := range gcpExclusiveFlags {
if fs.Changed(flag) {
bad = append(bad, flag)
}
}
}
if provider != cluster.ProviderAWS {
for _, flag := range awsExclusiveFlags {
if fs.Changed(flag) {
bad = append(bad, flag)
}
}
}
if len(bad) == 1 {
return fmt.Errorf("%s flag is meaningless using chosen provider", bad[0])
} else if len(bad) > 1 {
return fmt.Errorf("%s flags are meaningless using chosen provider",
strings.Join(bad, ", "))
}
return nil
}

const (
additionalComputeSecurityGroupIdsFlag = "additional-compute-security-group-ids"
additionalInfraSecurityGroupIdsFlag = "additional-infra-security-group-ids"
Expand Down Expand Up @@ -344,8 +390,8 @@ func AddProviderFlag(fs *pflag.FlagSet, value *string) {
fs.StringVar(
value,
"provider",
"aws",
"The cloud provider to create the cluster on",
"",
"The cloud provider to create the cluster on. Supported options are [aws gcp]",
)
SetQuestion(fs, "provider", "Cloud provider:")
}
Expand Down

0 comments on commit cb6f76c

Please sign in to comment.