Skip to content

Commit

Permalink
fix: Fix casing during reading of resource group (#1673)
Browse files Browse the repository at this point in the history
* fix: Fix casing of filter name for resource group

* fix: fix broken unit test
  • Loading branch information
zekisherif authored Nov 1, 2024
1 parent a8f6ca1 commit 1167cb9
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 25 deletions.
94 changes: 87 additions & 7 deletions api/resource_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
_ "embed"
"encoding/json"
"fmt"
"strings"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -95,14 +96,52 @@ func (svc *ResourceGroupsService) List() (response ResourceGroupsResponse, err e
return rawResponse, err
}

err = sanitizeFieldsInRawResponseList(&rawResponse, &response)
if err != nil {
return rawResponse, err
}

return rawResponse, nil
}

func sanitizeFieldsInRawResponse(rawResponse *ResourceGroupResponse, response interface{}) error {
// update filters keys to match the query template
updateFiltersKeys(&rawResponse.Data)

j, err := json.Marshal(rawResponse)
if err != nil {
return err
}

return json.Unmarshal(j, &response)
}

func sanitizeFieldsInRawResponseList(rawResponse *ResourceGroupsResponse, response interface{}) error {
for i := range rawResponse.Data {
// update filters keys to match the query template
updateFiltersKeys(&rawResponse.Data[i])
}

j, err := json.Marshal(rawResponse)
if err != nil {
return err
}

return json.Unmarshal(j, &response)
}

func (svc *ResourceGroupsService) Create(group ResourceGroupData) (
response ResourceGroupResponse,
err error,
) {
err = svc.create(group, &response)
var rawResponse ResourceGroupResponse
err = svc.create(group, &rawResponse)
if err != nil {
return
}

err = sanitizeFieldsInRawResponse(&rawResponse, &response)

return
}

Expand All @@ -117,14 +156,58 @@ func (svc *ResourceGroupsService) Update(data *ResourceGroupData) (
guid := data.ID()
data.ResetResourceGUID()

err = svc.update(guid, data, &response)
var rawResponse ResourceGroupResponse
err = svc.update(guid, data, &rawResponse)

if err != nil {
return
}

err = sanitizeFieldsInRawResponse(&rawResponse, &response)

return
}

func collectFilterNames(children []*RGChild, filterNames map[string]string) {
for _, child := range children {
if child.FilterName != "" {
normalizedKey := strings.ReplaceAll(strings.ToLower(child.FilterName), "_", "")
filterNames[normalizedKey] = child.FilterName
}
if len(child.Children) > 0 {
collectFilterNames(child.Children, filterNames)
}
}
}

/*
updateFiltersKeys updates the keys in the Filters map of ResourceGroupData to ensure they match the filter names
defined in the nested children of the query expression. This is necessary because JSON decoding/encoding can
convert keys to camel case, causing mismatches. The function normalizes the keys by removing underscores and
converting them to lower case, then compares them with the filter names. If a mismatch is found, the key is
updated to the value in RGExpression.Children
*/
func updateFiltersKeys(data *ResourceGroupData) {
if data.Query == nil || data.Query.Expression == nil {
return
}

filterNames := make(map[string]string)
collectFilterNames(data.Query.Expression.Children, filterNames)

updatedFilters := make(map[string]*RGFilter)
for key, value := range data.Query.Filters {
normalizedKey := strings.ReplaceAll(strings.ToLower(key), "_", "")
if _, exists := filterNames[normalizedKey]; exists {
updatedFilters[filterNames[normalizedKey]] = value
} else {
updatedFilters[key] = value
}
}

data.Query.Filters = updatedFilters
}

func (group *ResourceGroupData) ResetResourceGUID() {
group.ResourceGroupGuid = ""
group.UpdatedBy = ""
Expand All @@ -149,20 +232,17 @@ func (svc *ResourceGroupsService) Delete(guid string) error {

func (svc *ResourceGroupsService) Get(guid string, response interface{}) error {
var rawResponse ResourceGroupResponse

err := svc.get(guid, &rawResponse)
if err != nil {
return err
}

j, err := json.Marshal(rawResponse)
err = sanitizeFieldsInRawResponse(&rawResponse, response)
if err != nil {
return err
}

err = json.Unmarshal(j, &response)
if err != nil {
return err
}
return nil
}

Expand Down
108 changes: 108 additions & 0 deletions api/resource_groups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package api_test

import (
"encoding/json"
"fmt"
"net/http"
"strings"
Expand Down Expand Up @@ -132,6 +133,113 @@ func TestResourceGroupGet(t *testing.T) {
})
}

func TestResourceGroupsGetCorrectlyParsersFilterNames(t *testing.T) {
var (
queryJson = `
{
"expression": {
"children": [
{
"filterName": "filter_account"
},
{
"filterName": "filter1"
},
{
"filterName": "filter2"
},
{
"children": [
{
"filterName": "team_Account"
}
],
"operator": "OR"
}
],
"operator": "AND"
},
"filters": {
"filter1": {
"field": "Resource Tag",
"key": "Hostname",
"operation": "INCLUDES",
"values": [
"*"
]
},
"filter2": {
"field": "Region",
"operation": "STARTS_WITH",
"values": [
"ap-south"
]
},
"filter_account": {
"field": "Account",
"operation": "EQUALS",
"values": [
"123456789012"
]
},
"team_Account": {
"field": "Account",
"operation": "EQUALS",
"values": [
"123456789012"
]
}
}
}
`
resourceGUID = intgguid.New()
vanillaType = "VANILLA"
apiPath = fmt.Sprintf("ResourceGroups/%s", resourceGUID)
vanillaGroup = singleVanillaResourceGroup(resourceGUID, vanillaType, queryJson)
fakeServer = lacework.MockServer()
)

fakeServer.MockToken("TOKEN")
defer fakeServer.Close()

fakeServer.MockAPI(apiPath,
func(w http.ResponseWriter, r *http.Request) {
if assert.Equal(t, "GET", r.Method, "Get() should be a GET method") {
fmt.Fprintf(w, generateResourceGroupResponse(vanillaGroup))
}
},
)

c, err := api.NewClient("test",
api.WithToken("TOKEN"),
api.WithURL(fakeServer.URL()),
)

assert.Nil(t, err)

t.Run("when resource groups GET is called. Filter keys are correctly parsed", func(t *testing.T) {
var response api.ResourceGroupResponse
err := c.V2.ResourceGroups.Get(resourceGUID, &response)
assert.Nil(t, err)
if assert.NotNil(t, response) {
assert.Equal(t, resourceGUID, response.Data.ResourceGroupGuid)
assert.Equal(t, "group_name", response.Data.Name)
assert.Equal(t, "VANILLA", response.Data.Type)
// assert that the filter names in queryjson matach RGQuery
var expectedQuery api.RGQuery
err = json.Unmarshal([]byte(queryJson), &expectedQuery)
assert.Nil(t, err)

assert.NotNil(t, response.Data.Query.Filters["filter_account"])
assert.Equal(t, expectedQuery.Filters["filter_account"], response.Data.Query.Filters["filter_account"])

assert.NotNil(t, response.Data.Query.Filters["team_Account"])
assert.Equal(t, expectedQuery.Filters["team_Account"], response.Data.Query.Filters["team_Account"])
}
})
}

func TestResourceGroupsDelete(t *testing.T) {
var (
resourceGUID = intgguid.New()
Expand Down
19 changes: 1 addition & 18 deletions cli/cmd/resource_groups.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,24 +208,7 @@ func promptCreateResourceGroup() error {
return err
}

switch group {
case "AWS":
return createResourceGroup("AWS")
case "AZURE":
return createResourceGroup("AZURE")
case "GCP":
return createResourceGroup("GCP")
case "CONTAINER":
return createResourceGroup("CONTAINER")
case "MACHINE":
return createResourceGroup("MACHINE")
case "OCI":
return createResourceGroup("OCI")
case "KUBERNETES":
return createResourceGroup("KUBERNETES")
default:
return errors.New("unknown resource group type")
}
return createResourceGroup(group)
}

func init() {
Expand Down

0 comments on commit 1167cb9

Please sign in to comment.