Skip to content

Commit

Permalink
Merge pull request clusterpedia-io#597 from nekomeowww/dev/fix-sqli
Browse files Browse the repository at this point in the history
feat: added new parameters for parameterized query
  • Loading branch information
Iceber authored Nov 27, 2023
2 parents d485cb1 + e19c45f commit f3e50db
Show file tree
Hide file tree
Showing 12 changed files with 5,051 additions and 28 deletions.
13 changes: 12 additions & 1 deletion pkg/storage/internalstorage/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,16 @@ const (
// owner: @cleverhu
// alpha: v0.3.0
AllowRawSQLQuery featuregate.Feature = "AllowRawSQLQuery"

// AllowParameterizedSQLQuery is a feature gate for the apiserver to allow querying by the parameterized SQL
// for better defense against SQL injection.
//
// Use either single whereSQLStatement field, a pair of whereSQLStatement with whereSQLParam, or
// whereSQLStatement with whereSQLJSONParams to pass the SQL it self and parameters.
//
// owner: @nekomeowww
// alpha: v0.8.0
AllowParameterizedSQLQuery featuregate.Feature = "AllowParameterizedSQLQuery"
)

func init() {
Expand All @@ -21,5 +31,6 @@ func init() {
// defaultInternalStorageFeatureGates consists of all known custom internalstorage feature keys.
// To add a new feature, define a key for it above and add it here.
var defaultInternalStorageFeatureGates = map[featuregate.Feature]featuregate.FeatureSpec{
AllowRawSQLQuery: {Default: false, PreRelease: featuregate.Alpha},
AllowRawSQLQuery: {Default: false, PreRelease: featuregate.Alpha},
AllowParameterizedSQLQuery: {Default: false, PreRelease: featuregate.Alpha},
}
56 changes: 44 additions & 12 deletions pkg/storage/internalstorage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package internalstorage

import (
"fmt"
"os"
"testing"

"github.com/DATA-DOG/go-sqlmock"
gmysql "gorm.io/driver/mysql"
Expand All @@ -10,34 +12,64 @@ import (
)

var (
postgresDB *gorm.DB
postgresDB *gorm.DB
postgresDBMock sqlmock.Sqlmock

mysqlVersions = []string{"8.0.27", "5.7.22"}
mysqlDBs = make(map[string]*gorm.DB, 2)
mysqlDBMocks = make(map[string]sqlmock.Sqlmock, 2)
)

func init() {
db, _, err := sqlmock.New()
func newMockedPostgresDB() (*gorm.DB, sqlmock.Sqlmock, error) {
mockedDB, mock, err := sqlmock.New()
if err != nil {
panic(fmt.Sprintf("sqlmock.New() failed: %v", err))
return nil, nil, fmt.Errorf("sqlmock.New() failed: %w", err)
}

postgresDB, err = gorm.Open(gpostgres.New(gpostgres.Config{Conn: db}))
gormDB, err := gorm.Open(gpostgres.New(gpostgres.Config{Conn: mockedDB}))
if err != nil {
panic(fmt.Sprintf("init postgresDB failed: %v", err))
return nil, nil, fmt.Errorf("init postgresDB failed: %w", err)
}

for _, version := range mysqlVersions {
db, mock, err := sqlmock.New()
return gormDB, mock, nil
}

func newMockedMySQLDB(version string) (*gorm.DB, sqlmock.Sqlmock, error) {
mockedDB, mock, err := sqlmock.New()
if err != nil {
return nil, nil, fmt.Errorf("sqlmock.New() failed: %w", err)
}

mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"VERSION()"}).AddRow(version))

mysqlDB, err := gorm.Open(gmysql.New(gmysql.Config{Conn: mockedDB}))
if err != nil {
return nil, nil, fmt.Errorf("init mysqlDB(%s) failed: %w", version, err)
}

return mysqlDB, mock, nil
}

func TestMain(m *testing.M) {
{
mockedDB, mock, err := newMockedPostgresDB()
if err != nil {
panic(fmt.Sprintf("sqlmock.New() failed: %v", err))
panic(err)
}
mock.ExpectQuery("SELECT VERSION()").WillReturnRows(sqlmock.NewRows([]string{"VERSION()"}).AddRow(version))

mysqlDB, err := gorm.Open(gmysql.New(gmysql.Config{Conn: db}))
postgresDB = mockedDB
postgresDBMock = mock
}

for _, version := range mysqlVersions {
mysqlDB, mock, err := newMockedMySQLDB(version)
if err != nil {
panic(fmt.Sprintf("init mysqlDB(%s) failed: %v", version, err))
panic(err)
}

mysqlDBs[version] = mysqlDB
mysqlDBMocks[version] = mock
}

os.Exit(m.Run())
}
174 changes: 167 additions & 7 deletions pkg/storage/internalstorage/util.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package internalstorage

import (
"encoding/base64"
"fmt"
"net/url"
"strconv"
"strings"

Expand All @@ -10,6 +12,7 @@ import (
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime/schema"
"k8s.io/apimachinery/pkg/selection"
"k8s.io/apimachinery/pkg/util/json"
"k8s.io/apimachinery/pkg/util/validation/field"
utilfeature "k8s.io/apiserver/pkg/util/feature"

Expand All @@ -19,9 +22,165 @@ import (
const (
SearchLabelFuzzyName = "internalstorage.clusterpedia.io/fuzzy-name"

// Raw query
URLQueryWhereSQL = "whereSQL"
// Parameterized query
URLQueryFieldWhereSQLStatement = "whereSQLStatement"
URLQueryFieldWhereSQLParam = "whereSQLParam"
URLQueryFieldWhereSQLJSONParams = "whereSQLJSONParams"
)

type URLQueryWhereSQLParams struct {
// Raw query
WhereSQL string
// Parameterized query
WhereSQLStatement string
WhereSQLParams []string
WhereSQLJSONParams []any
}

// NewURLQueryWhereSQLParamsFromURLValues resolves parameters from passed in url.Values.
// A k8s.io/apimachinery/pkg/api/errors.StatusError will be returned if decoding or unmarshalling failed
// only when the value of "whereSQLJSONParams" is present.
//
// It recognizes the following query fields for parameters:
//
// "whereSQL"
// "whereSQLStatement"
// "whereSQLParam"
// "whereSQLJSONParams"
func NewURLQueryWhereSQLParamsFromURLValues(urlQuery url.Values) (URLQueryWhereSQLParams, error) {
var params URLQueryWhereSQLParams

whereClause, ok := urlQuery[URLQueryWhereSQL]
if ok && len(whereClause) > 0 {
params.WhereSQL = whereClause[0]
}

whereClauseStatement, ok := urlQuery[URLQueryFieldWhereSQLStatement]
if ok && len(whereClauseStatement) > 0 {
params.WhereSQLStatement = whereClauseStatement[0]
}

whereClauseParams, ok := urlQuery[URLQueryFieldWhereSQLParam]
if ok {
params.WhereSQLParams = whereClauseParams
}

whereClauseJSONParams, ok := urlQuery[URLQueryFieldWhereSQLJSONParams]
if ok && len(whereClauseJSONParams) > 0 {
decodedBytesContent, err := base64.StdEncoding.DecodeString(whereClauseJSONParams[0])
if err != nil {
return URLQueryWhereSQLParams{}, apierrors.NewInvalid(
schema.GroupKind{Group: internal.GroupName, Kind: "ListOptions"},
"urlQuery",
field.ErrorList{
field.Invalid(
field.NewPath(URLQueryFieldWhereSQLJSONParams),
whereClauseJSONParams[0],
fmt.Sprintf("failed to decode base64 string: %v", err),
),
},
)
}

params.WhereSQLJSONParams = make([]any, 0)
err = json.Unmarshal(decodedBytesContent, &params.WhereSQLJSONParams)
if err != nil {
return URLQueryWhereSQLParams{}, apierrors.NewInvalid(
schema.GroupKind{Group: internal.GroupName, Kind: "ListOptions"},
"urlQuery",
field.ErrorList{
field.Invalid(
field.NewPath(URLQueryFieldWhereSQLJSONParams),
whereClauseJSONParams[0],
fmt.Sprintf("failed to unmarshal decoded base64 string to JSON array: %v", err),
),
},
)
}
}

if (len(params.WhereSQLParams) > 0 || len(params.WhereSQLJSONParams) > 0) && params.WhereSQLStatement == "" {
return URLQueryWhereSQLParams{}, apierrors.NewInvalid(
schema.GroupKind{Group: internal.GroupName, Kind: "ListOptions"},
"urlQuery",
field.ErrorList{
field.Invalid(
field.NewPath(URLQueryFieldWhereSQLStatement),
whereClauseStatement,
fmt.Sprintf("required when either %s or %s was provided", URLQueryFieldWhereSQLParam, URLQueryFieldWhereSQLJSONParams),
),
},
)
}

return params, nil
}

func applyListOptionsURLQueryParameterizedQueryToWhereClause(query *gorm.DB, params URLQueryWhereSQLParams) *gorm.DB {
if params.WhereSQLStatement == "" {
return query
}

// If a string of numbers is passed in from SQL, the query will be taken as ID by default.
// If the SQL contains English letter, it will be passed in as column.

if len(params.WhereSQLJSONParams) > 0 {
return query.Where(params.WhereSQLStatement, params.WhereSQLJSONParams...)
}
if len(params.WhereSQLParams) > 0 {
anyParameters := make([]any, len(params.WhereSQLParams))

for i := range params.WhereSQLParams {
anyParameters[i] = params.WhereSQLParams[i]
}

return query.Where(params.WhereSQLStatement, anyParameters...)
}

return query.Where(params.WhereSQLStatement)
}

// applyListOptionsURLQueryToWhereClause applies the where sql related parameters from url query to the where clause of the query.
//
// By design, both the parameters of whereSQLStatement and whereSQL will be accepted and be part of the query in order when
// AllowRawSQLQuery feature gate is enabled, and only whereSQLStatement will be accepted and be part of the query when
// AllowParameterizedSQLQuery feature gate is enabled.
func applyListOptionsURLQueryToWhereClause(query *gorm.DB, urlValues url.Values, allowRawSQLQueryEnabled bool, allowParameterizedSQLQueryEnabled bool) (*gorm.DB, error) {
if !allowRawSQLQueryEnabled && !allowParameterizedSQLQueryEnabled {
return query, nil
}

urlQueryParams, err := NewURLQueryWhereSQLParamsFromURLValues(urlValues)
if err != nil {
return query, err
}

if allowRawSQLQueryEnabled {
// use parameterized query first if statement was provided
//
// since users will need to migrate from caller (business) side first and make their transition
// to the newly added feature gate AllowParameterizedSQLQuery step by step, therefore a compatible
// implementation is required here to allow the migration and transition from caller (business) side
// while the existing feature gates that enabled for Clusterpedia deployment can be left as untouched
// and keep working as expected.
if urlQueryParams.WhereSQLStatement != "" {
return applyListOptionsURLQueryParameterizedQueryToWhereClause(query, urlQueryParams), nil
}
// otherwise, fallbacks to raw query
if urlQueryParams.WhereSQL != "" {
return query.Where(urlQueryParams.WhereSQL), nil
}
}

if allowParameterizedSQLQueryEnabled && urlQueryParams.WhereSQLStatement != "" {
return applyListOptionsURLQueryParameterizedQueryToWhereClause(query, urlQueryParams), nil
}

return query, nil
}

func applyListOptionsToQuery(query *gorm.DB, opts *internal.ListOptions, applyFn func(query *gorm.DB, opts *internal.ListOptions) (*gorm.DB, error)) (int64, *int64, *gorm.DB, error) {
switch len(opts.ClusterNames) {
case 0:
Expand Down Expand Up @@ -55,13 +214,14 @@ func applyListOptionsToQuery(query *gorm.DB, opts *internal.ListOptions, applyFn
query = query.Where("created_at < ?", opts.Before.Time.UTC())
}

if utilfeature.DefaultMutableFeatureGate.Enabled(AllowRawSQLQuery) {
if len(opts.URLQuery[URLQueryWhereSQL]) > 0 {
// TODO: prevent SQL injection.
// If a string of numbers is passed in from SQL, the query will be taken as ID by default.
// If the SQL contains English letter, it will be passed in as column.
query = query.Where(opts.URLQuery[URLQueryWhereSQL][0])
}
query, err := applyListOptionsURLQueryToWhereClause(
query,
opts.URLQuery,
utilfeature.DefaultMutableFeatureGate.Enabled(AllowRawSQLQuery),
utilfeature.DefaultMutableFeatureGate.Enabled((AllowParameterizedSQLQuery)),
)
if err != nil {
return 0, nil, nil, err
}

if opts.LabelSelector != nil {
Expand Down
Loading

0 comments on commit f3e50db

Please sign in to comment.