Skip to content

Commit

Permalink
bu fixes for nil deref (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickb-orca authored Nov 15, 2024
1 parent 399fa2d commit 1038b1a
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 79 deletions.
10 changes: 5 additions & 5 deletions orcasecurity/api_client/business_unit.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@ type BusinessUnitFilter struct {
}

type BusinessUnitShiftLeftFilter struct {
ShiftLeftProjects []string `json:"shiftleft_project_id"`
ShiftLeftProjects []string `json:"shiftleft_project_id,omitempty"`
}

type BusinessUnit struct {
ID string `json:"filter_id"`
Name string `json:"name"`
Filter BusinessUnitFilter `json:"filter_data,omitempty"`
SLFilter BusinessUnitShiftLeftFilter `json:"shiftleft_filter_data"`
ID string `json:"filter_id"`
Name string `json:"name"`
Filter BusinessUnitFilter `json:"filter_data,omitempty"`
ShiftLeftFilter *BusinessUnitShiftLeftFilter `json:"shiftleft_filter_data,omitempty"`
}

type businessUnitAPIResponseType struct {
Expand Down
131 changes: 66 additions & 65 deletions orcasecurity/business_unit/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,16 @@ func (r *businessUnitResource) Create(ctx context.Context, req resource.CreateRe
return
}

if len(plan.ShiftLeftFilter.ShiftLeftProjects) > 0 && len(plan.Filter.CloudAccounts) == 0 {
slFilter, _ := generateShiftLeftProjectFilter(plan.ShiftLeftFilter)
if plan.Filter != nil && plan.Filter.CloudProvider != nil {
filter, filterDiags := generateCloudProviderFilter(plan.Filter)
diags.Append(filterDiags...)

//EmptyBuslf := api_client.BusinessUnitShiftLeftFilter{}

createReq := api_client.BusinessUnit{
Name: plan.Name.ValueString(),
SLFilter: slFilter,
Name: plan.Name.ValueString(),
Filter: filter,
ShiftLeftFilter: nil,
}

instance, err := r.apiClient.CreateBusinessUnit(createReq)
Expand All @@ -233,20 +238,13 @@ func (r *businessUnitResource) Create(ctx context.Context, req resource.CreateRe
if resp.Diagnostics.HasError() {
return
}
} else if len(plan.ShiftLeftFilter.ShiftLeftProjects) > 0 && len(plan.Filter.CloudAccounts) > 0 {
} else if plan.Filter != nil && plan.Filter.CloudAccounts != nil && (plan.ShiftLeftFilter == nil || plan.ShiftLeftFilter.ShiftLeftProjects == nil) {
filter, filterDiags := generateCloudAccountsFilter(plan.Filter)
diags.Append(filterDiags...)

createReq := api_client.BusinessUnit{}

if len(plan.ShiftLeftFilter.ShiftLeftProjects) > 0 {
slFilter, _ := generateShiftLeftProjectFilter(plan.ShiftLeftFilter)
createReq = api_client.BusinessUnit{
Name: plan.Name.ValueString(),
SLFilter: slFilter,
Filter: filter,
}

createReq := api_client.BusinessUnit{
Name: plan.Name.ValueString(),
Filter: filter,
}

instance, err := r.apiClient.CreateBusinessUnit(createReq)
Expand All @@ -265,13 +263,11 @@ func (r *businessUnitResource) Create(ctx context.Context, req resource.CreateRe
if resp.Diagnostics.HasError() {
return
}
} else if len(plan.Filter.CloudProvider) > 0 {
filter, filterDiags := generateCloudProviderFilter(plan.Filter)
diags.Append(filterDiags...)

} else if plan.ShiftLeftFilter != nil && plan.ShiftLeftFilter.ShiftLeftProjects != nil && (plan.Filter == nil || plan.Filter.CloudAccounts == nil) {
slFilter, _ := generateShiftLeftProjectFilter(plan.ShiftLeftFilter)
createReq := api_client.BusinessUnit{
Name: plan.Name.ValueString(),
Filter: filter,
Name: plan.Name.ValueString(),
ShiftLeftFilter: &slFilter,
}

instance, err := r.apiClient.CreateBusinessUnit(createReq)
Expand All @@ -290,13 +286,20 @@ func (r *businessUnitResource) Create(ctx context.Context, req resource.CreateRe
if resp.Diagnostics.HasError() {
return
}
} else if len(plan.Filter.CustomTags) > 0 {
filter, filterDiags := generateCustomTagsFilter(plan.Filter)
} else if plan.ShiftLeftFilter != nil && plan.ShiftLeftFilter.ShiftLeftProjects != nil && plan.Filter != nil && plan.Filter.CloudAccounts != nil {
filter, filterDiags := generateCloudAccountsFilter(plan.Filter)
diags.Append(filterDiags...)

createReq := api_client.BusinessUnit{
Name: plan.Name.ValueString(),
Filter: filter,
createReq := api_client.BusinessUnit{}

if len(plan.ShiftLeftFilter.ShiftLeftProjects) > 0 {
slFilter, _ := generateShiftLeftProjectFilter(plan.ShiftLeftFilter)
createReq = api_client.BusinessUnit{
Name: plan.Name.ValueString(),
ShiftLeftFilter: &slFilter,
Filter: filter,
}

}

instance, err := r.apiClient.CreateBusinessUnit(createReq)
Expand All @@ -315,8 +318,8 @@ func (r *businessUnitResource) Create(ctx context.Context, req resource.CreateRe
if resp.Diagnostics.HasError() {
return
}
} else if len(plan.Filter.AccountTags) > 0 {
filter, filterDiags := generateAccountTagsFilter(plan.Filter)
} else if plan.Filter != nil && plan.Filter.CustomTags != nil {
filter, filterDiags := generateCustomTagsFilter(plan.Filter)
diags.Append(filterDiags...)

createReq := api_client.BusinessUnit{
Expand All @@ -340,8 +343,8 @@ func (r *businessUnitResource) Create(ctx context.Context, req resource.CreateRe
if resp.Diagnostics.HasError() {
return
}
} else if len(plan.Filter.InventoryTags) > 0 {
filter, filterDiags := generateInventoryTagsFilter(plan.Filter)
} else if plan.Filter != nil && plan.Filter.AccountTags != nil {
filter, filterDiags := generateAccountTagsFilter(plan.Filter)
diags.Append(filterDiags...)

createReq := api_client.BusinessUnit{
Expand All @@ -365,8 +368,8 @@ func (r *businessUnitResource) Create(ctx context.Context, req resource.CreateRe
if resp.Diagnostics.HasError() {
return
}
} else if len(plan.Filter.CloudAccounts) > 0 && len(plan.ShiftLeftFilter.ShiftLeftProjects) == 0 {
filter, filterDiags := generateCloudAccountsFilter(plan.Filter)
} else if plan.Filter != nil && plan.Filter.InventoryTags != nil {
filter, filterDiags := generateInventoryTagsFilter(plan.Filter)
diags.Append(filterDiags...)

createReq := api_client.BusinessUnit{
Expand Down Expand Up @@ -451,39 +454,44 @@ func (r *businessUnitResource) Update(ctx context.Context, req resource.UpdateRe
return
}

if len(plan.ShiftLeftFilter.ShiftLeftProjects) > 0 && len(plan.Filter.CloudAccounts) == 0 {
slFilter, _ := generateShiftLeftProjectFilter(plan.ShiftLeftFilter)
if len(plan.Filter.CloudProvider) > 0 {
filter, filterDiags := generateCloudProviderFilter(plan.Filter)
diags.Append(filterDiags...)

updateReq := api_client.BusinessUnit{
ID: plan.ID.ValueString(),
Name: plan.Name.ValueString(),
SLFilter: slFilter,
Name: plan.Name.ValueString(),
Filter: filter,
}

instance, err := r.apiClient.UpdateBusinessUnit(updateReq.ID, updateReq)

_, err := r.apiClient.UpdateBusinessUnit(plan.ID.ValueString(), updateReq)
if err != nil {
resp.Diagnostics.AddError(
"Error updating business unit",
"Could not update business unit, unexpected error: "+err.Error(),
)
return
}
plan.ID = types.StringValue(instance.ID)

diags = resp.State.Set(ctx, plan)
_, err = r.apiClient.GetBusinessUnit(plan.ID.ValueString())
if err != nil {
resp.Diagnostics.AddError(
"Error reading business unit",
"Could not read Business Unit ID: "+plan.ID.ValueString()+": "+err.Error(),
)
return
}

diags = resp.State.Set(ctx, &plan)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
}
} else if len(plan.ShiftLeftFilter.ShiftLeftProjects) > 0 && len(plan.Filter.CloudAccounts) > 0 {
filter, filterDiags := generateCloudAccountsFilter(plan.Filter)
diags.Append(filterDiags...)
} else if len(plan.ShiftLeftFilter.ShiftLeftProjects) > 0 && len(plan.Filter.CloudAccounts) == 0 {
slFilter, _ := generateShiftLeftProjectFilter(plan.ShiftLeftFilter)

updateReq := api_client.BusinessUnit{
Name: plan.Name.ValueString(),
SLFilter: slFilter,
Filter: filter,
ID: plan.ID.ValueString(),
Name: plan.Name.ValueString(),
ShiftLeftFilter: &slFilter,
}

instance, err := r.apiClient.UpdateBusinessUnit(updateReq.ID, updateReq)
Expand All @@ -502,36 +510,29 @@ func (r *businessUnitResource) Update(ctx context.Context, req resource.UpdateRe
if resp.Diagnostics.HasError() {
return
}
}

if len(plan.Filter.CloudProvider) > 0 {
filter, filterDiags := generateCloudProviderFilter(plan.Filter)
} else if len(plan.ShiftLeftFilter.ShiftLeftProjects) > 0 && len(plan.Filter.CloudAccounts) > 0 {
filter, filterDiags := generateCloudAccountsFilter(plan.Filter)
diags.Append(filterDiags...)
slFilter, _ := generateShiftLeftProjectFilter(plan.ShiftLeftFilter)

updateReq := api_client.BusinessUnit{
Name: plan.Name.ValueString(),
Filter: filter,
Name: plan.Name.ValueString(),
ShiftLeftFilter: &slFilter,
Filter: filter,
}

_, err := r.apiClient.UpdateBusinessUnit(plan.ID.ValueString(), updateReq)
instance, err := r.apiClient.UpdateBusinessUnit(updateReq.ID, updateReq)

if err != nil {
resp.Diagnostics.AddError(
"Error updating business unit",
"Could not update business unit, unexpected error: "+err.Error(),
)
return
}
plan.ID = types.StringValue(instance.ID)

_, err = r.apiClient.GetBusinessUnit(plan.ID.ValueString())
if err != nil {
resp.Diagnostics.AddError(
"Error reading business unit",
"Could not read Business Unit ID: "+plan.ID.ValueString()+": "+err.Error(),
)
return
}

diags = resp.State.Set(ctx, &plan)
diags = resp.State.Set(ctx, plan)
resp.Diagnostics.Append(diags...)
if resp.Diagnostics.HasError() {
return
Expand Down
56 changes: 52 additions & 4 deletions orcasecurity/business_unit/resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ func TestAccBusinessUnitResource_Basic(t *testing.T) {
{
Config: orcasecurity.TestProviderConfig + `
resource "orcasecurity_business_unit" "business_unit_for_aws" {
name = "AWS"
name = "AWSBU"
filter_data = {
cloud_provider = ["aws"]
}
}
`,
Check: resource.ComposeAggregateTestCheckFunc(
resource.TestCheckResourceAttr("orcasecurity_business_unit.business_unit_for_aws", "name", "AWS"),
resource.TestCheckResourceAttr("orcasecurity_business_unit.business_unit_for_aws", "name", "AWSBU"),
resource.TestCheckResourceAttr("orcasecurity_business_unit.business_unit_for_aws", "filter_data.cloud_provider[0]", "aws"),
),
},
Expand All @@ -36,17 +36,65 @@ resource "orcasecurity_business_unit" "business_unit_for_aws" {
{
Config: orcasecurity.TestProviderConfig + `
resource "orcasecurity_business_unit" "business_unit_for_azure" {
name = "Azure"
name = "AzureBU"
filter_data = {
cloud_provider = ["azure"]
}
}
`,
Check: resource.ComposeAggregateTestCheckFunc(
resource.TestCheckResourceAttr("orcasecurity_business_unit.business_unit_for_azure", "name", "Azure"),
resource.TestCheckResourceAttr("orcasecurity_business_unit.business_unit_for_azure", "name", "AzureBU"),
resource.TestCheckResourceAttr("orcasecurity_business_unit.business_unit_for_azure", "filter_data.cloud_provider[0]", "azure"),
),
},
/*{
Config: "", // Empty config forces destroy of all resources
},*/
},
})
}

func TestAccBusinessUnitResource_ShiftLeft(t *testing.T) {
resource.Test(t, resource.TestCase{
ProtoV6ProviderFactories: orcasecurity.TestAccProtoV6ProviderFactories,
Steps: []resource.TestStep{
// create
{
Config: orcasecurity.TestProviderConfig + `
resource "orcasecurity_business_unit" "business_unit_for_sl" {
name = "SL BU"
shiftleft_filter_data = {
shiftleft_project_id = ["577ba5de-3837-4db1-999f-dd2524e09e52"]
}
}
`,
Check: resource.ComposeAggregateTestCheckFunc(
resource.TestCheckResourceAttr("orcasecurity_business_unit.business_unit_for_sl", "shiftleft_filter_data.shiftleft_project_id[0]", "577ba5de-3837-4db1-999f-dd2524e09e52"),
),
},
// import
{
ResourceName: "orcasecurity_business_unit.business_unit_for_sl",
ImportState: true,
ImportStateVerify: true,
},
// update
{
Config: orcasecurity.TestProviderConfig + `
resource "orcasecurity_business_unit" "business_unit_for_sl" {
name = "AWS"
shiftleft_filter_data = {
shiftleft_project_id = ["577ba5de-3837-4db1-999f-dd2524e09e52"]
}
}
`,
Check: resource.ComposeAggregateTestCheckFunc(
resource.TestCheckResourceAttr("orcasecurity_business_unit.business_unit_for_sl", "shiftleft_filter_data.shiftleft_project_id[0]", "577ba5de-3837-4db1-999f-dd2524e09e52"),
),
},
/*{
Config: "", // Empty config forces destroy of all resources
},*/
},
})
}
10 changes: 5 additions & 5 deletions orcasecurity/trusted_cloud_account/resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@ func TestTrustedCloudAccountResource_Basic(t *testing.T) {
// create
{
Config: orcasecurity.TestProviderConfig + `
resource "orcasecurity_trusted_cloud_account" "account-1" {
account_name = "test44912"
description = "test2"
resource "orcasecurity_trusted_cloud_account" "Orca_TF_Provider_Acceptance_Test_Account" {
account_name = "Orca TF Provider Acceptance Test Account"
description = "Dummy Description"
cloud_provider = "aws"
cloud_provider_id = "12341234123445678912"
}
`,
Check: resource.ComposeAggregateTestCheckFunc(
resource.TestCheckResourceAttr("orcasecurity_trusted_cloud_account.account-1", "name", "test44912"),
resource.TestCheckResourceAttr("orcasecurity_trusted_cloud_account.account-1", "description", "test2"),
resource.TestCheckResourceAttr("orcasecurity_trusted_cloud_account.Orca_TF_Provider_Acceptance_Test_Account", "name", "Orca TF Provider Acceptance Test Account"),
resource.TestCheckResourceAttr("orcasecurity_trusted_cloud_account.Orca_TF_Provider_Acceptance_Test_Account", "description", "Dummy Description"),
),
},
// import
Expand Down

0 comments on commit 1038b1a

Please sign in to comment.