Skip to content

Commit

Permalink
fix: unsupported operations by driver
Browse files Browse the repository at this point in the history
Signed-off-by: Azanul <[email protected]>
  • Loading branch information
Azanul committed Jun 12, 2024
1 parent 635f03b commit 75e0f80
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 89 deletions.
14 changes: 4 additions & 10 deletions controller/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,13 @@ func (ctrl *Controller) CountResources(c context.Context, provider, name string)
}

func (ctrl *Controller) InsertAccount(c context.Context, account models.Account) (lastId int64, err error) {
result, err := ctrl.repo.HandleQuery(c, repository.InsertKey, &account, nil, "")
if err != nil {
return
}
return result.LastInsertId()
lastId, err = ctrl.repo.HandleQuery(c, repository.InsertKey, &account, nil, "")
return
}

func (ctrl *Controller) RescanAccount(c context.Context, account *models.Account, accountId string) (rows int64, err error) {
res, err := ctrl.repo.HandleQuery(c, repository.ReScanAccountKey, account, [][3]string{{"id", "=", accountId}, {"status", "=", "CONNECTED"}}, "")
if err != nil {
return 0, err
}
return res.RowsAffected()
rows, err = ctrl.repo.HandleQuery(c, repository.ReScanAccountKey, account, [][3]string{{"id", "=", accountId}, {"status", "=", "CONNECTED"}}, "")
return
}

func (ctrl *Controller) DeleteAccount(c context.Context, accountId string) (err error) {
Expand Down
9 changes: 3 additions & 6 deletions controller/alerts.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,12 @@ import (
)

func (ctrl *Controller) InsertAlert(c context.Context, alert models.Alert) (alertId int64, err error) {
result, err := ctrl.repo.HandleQuery(c, repository.InsertKey, &alert, nil, "")
if err != nil {
return
}
return result.LastInsertId()
alertId, err = ctrl.repo.HandleQuery(c, repository.InsertKey, &alert, nil, "")
return
}

func (ctrl *Controller) UpdateAlert(c context.Context, alert models.Alert, alertId string) (err error) {
_, err = ctrl.repo.HandleQuery(c, repository.UpdateAlertKey, &alert, [][3]string{{"id", "=", alertId}},"")
_, err = ctrl.repo.HandleQuery(c, repository.UpdateAlertKey, &alert, [][3]string{{"id", "=", alertId}}, "")
return
}

Expand Down
3 changes: 1 addition & 2 deletions controller/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package controller

import (
"context"
"database/sql"

"github.com/tailwarden/komiser/models"
)
Expand Down Expand Up @@ -32,7 +31,7 @@ type accountOutput struct {
}

type Repository interface {
HandleQuery(context.Context, string, interface{}, [][3]string, string) (sql.Result, error)
HandleQuery(context.Context, string, interface{}, [][3]string, string) (int64, error)
GenerateFilterQuery(view models.View, queryTitle string, arguments []int64, queryParameter string) ([]string, error)
}

Expand Down
7 changes: 2 additions & 5 deletions controller/views.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@ func (ctrl *Controller) ListViews(c context.Context) (views []models.View, err e
}

func (ctrl *Controller) InsertView(c context.Context, view models.View) (viewId int64, err error) {
result, err := ctrl.repo.HandleQuery(c, repository.InsertKey, &view, nil, "")
if err != nil {
return
}
return result.LastInsertId()
viewId, err = ctrl.repo.HandleQuery(c, repository.InsertKey, &view, nil, "")
return
}

func (ctrl *Controller) UpdateView(c context.Context, view models.View, viewId string) (err error) {
Expand Down
35 changes: 24 additions & 11 deletions repository/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package repository

import (
"context"
"database/sql"
"fmt"

"github.com/uptrace/bun"
Expand Down Expand Up @@ -78,27 +77,35 @@ func ExecuteSelect(ctx context.Context, db *bun.DB, schema interface{}, conditio
return q.Scan(ctx, schema)
}

func ExecuteInsert(ctx context.Context, db *bun.DB, schema interface{}) (sql.Result, error) {
resp, err := db.NewInsert().Model(schema).Exec(ctx)
func ExecuteInsert(ctx context.Context, db *bun.DB, schema interface{}) (id int64, err error) {
res, err := db.NewInsert().Model(schema).Returning("id").Exec(ctx, &id)
if err != nil {
return resp, err
_id, err := res.LastInsertId()
if err != nil {
id = _id
}
}
return resp, nil
return
}

func ExecuteDelete(ctx context.Context, db *bun.DB, schema interface{}, conditions [][3]string) (sql.Result, error) {
func ExecuteDelete(ctx context.Context, db *bun.DB, schema interface{}, conditions [][3]string) (int64, error) {
q := db.NewDelete().Model(schema)

q = addWhereClause(q.QueryBuilder(), conditions).Unwrap().(*bun.DeleteQuery)

resp, err := q.Exec(ctx)
if err != nil {
return resp, err
return 0, err
}
return resp, nil

rowsAffected, err := resp.RowsAffected()
if err != nil {
return 0, err
}
return rowsAffected, nil
}

func ExecuteUpdate(ctx context.Context, db *bun.DB, schema interface{}, columns []string, conditions [][3]string) (sql.Result, error) {
func ExecuteUpdate(ctx context.Context, db *bun.DB, schema interface{}, columns []string, conditions [][3]string) (int64, error) {
q := db.NewUpdate().Model(schema).Column(columns...)

q = addWhereClause(q.QueryBuilder(), conditions).Unwrap().(*bun.UpdateQuery)
Expand All @@ -107,9 +114,15 @@ func ExecuteUpdate(ctx context.Context, db *bun.DB, schema interface{}, columns

resp, err := q.Exec(ctx)
if err != nil {
return resp, err
return 0, err
}
return resp, nil

rowsAffected, err := resp.RowsAffected()
if err != nil {
return 0, err
}

return rowsAffected, nil
}

func addWhereClause(query bun.QueryBuilder, conditions [][3]string) bun.QueryBuilder {
Expand Down
47 changes: 22 additions & 25 deletions repository/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package postgres

import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strconv"
Expand All @@ -15,7 +14,7 @@ import (
)

type Repository struct {
mu sync.RWMutex
mu sync.RWMutex
db *bun.DB
queries map[string]repository.Object
}
Expand Down Expand Up @@ -98,8 +97,8 @@ var Queries = map[string]repository.Object{
Type: repository.RAW,
Query: "SELECT DISTINCT(account) FROM resources",
},
repository.ListResourceWithFilter : {
Type: repository.RAW,
repository.ListResourceWithFilter: {
Type: repository.RAW,
Query: "",
Params: []string{
"(name LIKE '%%%s%%' OR region LIKE '%%%s%%' OR service LIKE '%%%s%%' OR provider LIKE '%%%s%%' OR account LIKE '%%%s%%' OR (value->>'key' LIKE '%%%s%%') OR (value->>'value' LIKE '%%%s%%'))",
Expand All @@ -110,15 +109,15 @@ var Queries = map[string]repository.Object{
"SELECT * FROM resources WHERE %s AND id NOT IN (%s) ORDER BY id LIMIT %d OFFSET %d",
},
},
repository.ListRelationWithFilter : {
Type: repository.RAW,
repository.ListRelationWithFilter: {
Type: repository.RAW,
Query: "",
Params: []string{
"SELECT DISTINCT resources.resource_id, resources.provider, resources.name, resources.service, resources.relations FROM resources WHERE (jsonb_array_length(relations) > 0)",
},
},
repository.ListStatsWithFilter : {
Type: repository.RAW,
repository.ListStatsWithFilter: {
Type: repository.RAW,
Query: "",
Params: []string{
"SELECT COUNT(*) as total FROM (SELECT DISTINCT region FROM resources CROSS JOIN jsonb_array_elements(tags) AS res WHERE %s) AS temp",
Expand All @@ -131,23 +130,21 @@ var Queries = map[string]repository.Object{
},
}

func (repo *Repository) HandleQuery(ctx context.Context, queryTitle string, schema interface{}, conditions [][3]string, rawQuery string) (sql.Result, error) {
var resp sql.Result
var err error
func (repo *Repository) HandleQuery(ctx context.Context, queryTitle string, schema interface{}, conditions [][3]string, rawQuery string) (resp int64, err error) {
repo.mu.RLock()
query, ok := Queries[queryTitle]
repo.mu.RUnlock()
if !ok {
return nil, repository.ErrQueryNotFound
return 0, repository.ErrQueryNotFound
}
switch query.Type {
case repository.RAW:
if rawQuery != "" && query.Query == "" {
err = repository.ExecuteRaw(ctx, repo.db, rawQuery, schema, conditions)
err = repository.ExecuteRaw(ctx, repo.db, rawQuery, schema, conditions)
} else {
err = repository.ExecuteRaw(ctx, repo.db, query.Query, schema, conditions)
}

case repository.SELECT:
err = repository.ExecuteSelect(ctx, repo.db, schema, conditions)

Expand All @@ -173,13 +170,13 @@ func (repo *Repository) GenerateFilterQuery(view models.View, queryTitle string,
if err != nil {
return nil, err
}
whereQueries = append(whereQueries, query)
whereQueries = append(whereQueries, query)
case "cost":
query, err := generateCostFilterQuery(filter)
if err != nil {
return nil, err
}
whereQueries = append(whereQueries, query)
whereQueries = append(whereQueries, query)
case "relation":
query, err := generateRelationFilterQuery(filter)
if err != nil {
Expand Down Expand Up @@ -210,10 +207,10 @@ func (repo *Repository) GenerateFilterQuery(view models.View, queryTitle string,
return queryBuilderWithFilter(view, queryTitle, arguments, queryParameter, filterWithTags, whereClause), nil
}

func queryBuilderWithFilter(view models.View, queryTitle string, arguments []int64, query string, withTags bool, whereClause string) []string {
func queryBuilderWithFilter(view models.View, queryTitle string, arguments []int64, query string, withTags bool, whereClause string) []string {
searchQuery := []string{}
limit, skip := arguments[0], arguments[1]
if len(view.Filters) == 0 {
if len(view.Filters) == 0 {
switch queryTitle {
case repository.ListRelationWithFilter:
return append(searchQuery, Queries[queryTitle].Params[0])
Expand All @@ -228,12 +225,12 @@ func queryBuilderWithFilter(view models.View, queryTitle string, arguments []in
return append(searchQuery, tempQuery)
}
} else if queryTitle == repository.ListRelationWithFilter {
return append(searchQuery, Queries[queryTitle].Params[0] + " AND " + whereClause)
return append(searchQuery, Queries[queryTitle].Params[0]+" AND "+whereClause)
}

if withTags {
if queryTitle == repository.ListStatsWithFilter {
for i := 0; i<3; i++ {
for i := 0; i < 3; i++ {
searchQuery = append(searchQuery, fmt.Sprintf(Queries[queryTitle].Params[i], whereClause))
}
return searchQuery
Expand All @@ -246,7 +243,7 @@ func queryBuilderWithFilter(view models.View, queryTitle string, arguments []in
return append(searchQuery, tempQuery)
} else {
if queryTitle == repository.ListStatsWithFilter {
for i := 3; i<6; i++ {
for i := 3; i < 6; i++ {
searchQuery = append(searchQuery, fmt.Sprintf(Queries[queryTitle].Params[i], whereClause))
}
return searchQuery
Expand Down Expand Up @@ -347,18 +344,18 @@ func generateCostFilterQuery(filter models.Filter) (string, error) {
if err != nil {
return "", err
}
return fmt.Sprintf("(cost >= %f AND cost <= %f)", min, max), nil
return fmt.Sprintf("(cost >= %f AND cost <= %f)", min, max), nil
case "GREATER_THAN":
cost, err := strconv.ParseFloat(filter.Values[0], 64)
if err != nil {
return "", err
}
}
return fmt.Sprintf("(cost > %f)", cost), err
case "LESS_THAN":
cost, err := strconv.ParseFloat(filter.Values[0], 64)
if err != nil {
return "", err
}
}
return fmt.Sprintf("(cost < %f)", cost), nil
default:
return "", fmt.Errorf("unsupported operator for cost field: %s", filter.Operator)
Expand Down Expand Up @@ -388,4 +385,4 @@ func generateRelationFilterQuery(filter models.Filter) (string, error) {
default:
return "", fmt.Errorf("unsupported operator: %s", filter.Operator)
}
}
}
Loading

0 comments on commit 75e0f80

Please sign in to comment.