diff --git a/.mockery.yml b/.mockery.yml index fe1c9357ff..43d807b00f 100644 --- a/.mockery.yml +++ b/.mockery.yml @@ -63,7 +63,6 @@ packages: interfaces: SqlConnector: SqlDbContainer: - PgPoolContainer: github.com/nucleuscloud/neosync/backend/pkg/sqlmanager: interfaces: SqlDatabase: diff --git a/backend/gen/go/db/dbschemas/postgresql/db.go b/backend/gen/go/db/dbschemas/postgresql/db.go index 8a85444a59..d6e2e6996b 100644 --- a/backend/gen/go/db/dbschemas/postgresql/db.go +++ b/backend/gen/go/db/dbschemas/postgresql/db.go @@ -6,15 +6,14 @@ package pg_queries import ( "context" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" + "database/sql" ) type DBTX interface { - Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) - Query(context.Context, string, ...interface{}) (pgx.Rows, error) - QueryRow(context.Context, string, ...interface{}) pgx.Row + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row } func New() *Queries { diff --git a/backend/gen/go/db/dbschemas/postgresql/mock_DBTX.go b/backend/gen/go/db/dbschemas/postgresql/mock_DBTX.go index b4512028a2..4ef33a891c 100644 --- a/backend/gen/go/db/dbschemas/postgresql/mock_DBTX.go +++ b/backend/gen/go/db/dbschemas/postgresql/mock_DBTX.go @@ -4,11 +4,9 @@ package pg_queries import ( context "context" + sql "database/sql" - pgconn "github.com/jackc/pgx/v5/pgconn" mock "github.com/stretchr/testify/mock" - - pgx "github.com/jackc/pgx/v5" ) // MockDBTX is an autogenerated mock type for the DBTX type @@ -24,26 +22,28 @@ func (_m *MockDBTX) EXPECT() *MockDBTX_Expecter { return &MockDBTX_Expecter{mock: &_m.Mock} } -// Exec provides a mock function with given fields: _a0, _a1, _a2 -func (_m *MockDBTX) Exec(_a0 context.Context, _a1 string, _a2 ...interface{}) (pgconn.CommandTag, error) { +// ExecContext provides a mock function with given fields: _a0, _a1, _a2 +func (_m *MockDBTX) ExecContext(_a0 context.Context, _a1 string, _a2 ...interface{}) (sql.Result, error) { var _ca []interface{} _ca = append(_ca, _a0, _a1) _ca = append(_ca, _a2...) ret := _m.Called(_ca...) if len(ret) == 0 { - panic("no return value specified for Exec") + panic("no return value specified for ExecContext") } - var r0 pgconn.CommandTag + var r0 sql.Result var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (sql.Result, error)); ok { return rf(_a0, _a1, _a2...) } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgconn.CommandTag); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) sql.Result); ok { r0 = rf(_a0, _a1, _a2...) } else { - r0 = ret.Get(0).(pgconn.CommandTag) + if ret.Get(0) != nil { + r0 = ret.Get(0).(sql.Result) + } } if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { @@ -55,21 +55,21 @@ func (_m *MockDBTX) Exec(_a0 context.Context, _a1 string, _a2 ...interface{}) (p return r0, r1 } -// MockDBTX_Exec_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Exec' -type MockDBTX_Exec_Call struct { +// MockDBTX_ExecContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ExecContext' +type MockDBTX_ExecContext_Call struct { *mock.Call } -// Exec is a helper method to define mock.On call +// ExecContext is a helper method to define mock.On call // - _a0 context.Context // - _a1 string // - _a2 ...interface{} -func (_e *MockDBTX_Expecter) Exec(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_Exec_Call { - return &MockDBTX_Exec_Call{Call: _e.mock.On("Exec", +func (_e *MockDBTX_Expecter) ExecContext(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_ExecContext_Call { + return &MockDBTX_ExecContext_Call{Call: _e.mock.On("ExecContext", append([]interface{}{_a0, _a1}, _a2...)...)} } -func (_c *MockDBTX_Exec_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_Exec_Call { +func (_c *MockDBTX_ExecContext_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_ExecContext_Call { _c.Call.Run(func(args mock.Arguments) { variadicArgs := make([]interface{}, len(args)-2) for i, a := range args[2:] { @@ -82,37 +82,96 @@ func (_c *MockDBTX_Exec_Call) Run(run func(_a0 context.Context, _a1 string, _a2 return _c } -func (_c *MockDBTX_Exec_Call) Return(_a0 pgconn.CommandTag, _a1 error) *MockDBTX_Exec_Call { +func (_c *MockDBTX_ExecContext_Call) Return(_a0 sql.Result, _a1 error) *MockDBTX_ExecContext_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockDBTX_ExecContext_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (sql.Result, error)) *MockDBTX_ExecContext_Call { + _c.Call.Return(run) + return _c +} + +// PrepareContext provides a mock function with given fields: _a0, _a1 +func (_m *MockDBTX) PrepareContext(_a0 context.Context, _a1 string) (*sql.Stmt, error) { + ret := _m.Called(_a0, _a1) + + if len(ret) == 0 { + panic("no return value specified for PrepareContext") + } + + var r0 *sql.Stmt + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*sql.Stmt, error)); ok { + return rf(_a0, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *sql.Stmt); ok { + r0 = rf(_a0, _a1) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*sql.Stmt) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(_a0, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockDBTX_PrepareContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PrepareContext' +type MockDBTX_PrepareContext_Call struct { + *mock.Call +} + +// PrepareContext is a helper method to define mock.On call +// - _a0 context.Context +// - _a1 string +func (_e *MockDBTX_Expecter) PrepareContext(_a0 interface{}, _a1 interface{}) *MockDBTX_PrepareContext_Call { + return &MockDBTX_PrepareContext_Call{Call: _e.mock.On("PrepareContext", _a0, _a1)} +} + +func (_c *MockDBTX_PrepareContext_Call) Run(run func(_a0 context.Context, _a1 string)) *MockDBTX_PrepareContext_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *MockDBTX_PrepareContext_Call) Return(_a0 *sql.Stmt, _a1 error) *MockDBTX_PrepareContext_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockDBTX_Exec_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)) *MockDBTX_Exec_Call { +func (_c *MockDBTX_PrepareContext_Call) RunAndReturn(run func(context.Context, string) (*sql.Stmt, error)) *MockDBTX_PrepareContext_Call { _c.Call.Return(run) return _c } -// Query provides a mock function with given fields: _a0, _a1, _a2 -func (_m *MockDBTX) Query(_a0 context.Context, _a1 string, _a2 ...interface{}) (pgx.Rows, error) { +// QueryContext provides a mock function with given fields: _a0, _a1, _a2 +func (_m *MockDBTX) QueryContext(_a0 context.Context, _a1 string, _a2 ...interface{}) (*sql.Rows, error) { var _ca []interface{} _ca = append(_ca, _a0, _a1) _ca = append(_ca, _a2...) ret := _m.Called(_ca...) if len(ret) == 0 { - panic("no return value specified for Query") + panic("no return value specified for QueryContext") } - var r0 pgx.Rows + var r0 *sql.Rows var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgx.Rows, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (*sql.Rows, error)); ok { return rf(_a0, _a1, _a2...) } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Rows); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) *sql.Rows); ok { r0 = rf(_a0, _a1, _a2...) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Rows) + r0 = ret.Get(0).(*sql.Rows) } } @@ -125,21 +184,21 @@ func (_m *MockDBTX) Query(_a0 context.Context, _a1 string, _a2 ...interface{}) ( return r0, r1 } -// MockDBTX_Query_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Query' -type MockDBTX_Query_Call struct { +// MockDBTX_QueryContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryContext' +type MockDBTX_QueryContext_Call struct { *mock.Call } -// Query is a helper method to define mock.On call +// QueryContext is a helper method to define mock.On call // - _a0 context.Context // - _a1 string // - _a2 ...interface{} -func (_e *MockDBTX_Expecter) Query(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_Query_Call { - return &MockDBTX_Query_Call{Call: _e.mock.On("Query", +func (_e *MockDBTX_Expecter) QueryContext(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_QueryContext_Call { + return &MockDBTX_QueryContext_Call{Call: _e.mock.On("QueryContext", append([]interface{}{_a0, _a1}, _a2...)...)} } -func (_c *MockDBTX_Query_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_Query_Call { +func (_c *MockDBTX_QueryContext_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_QueryContext_Call { _c.Call.Run(func(args mock.Arguments) { variadicArgs := make([]interface{}, len(args)-2) for i, a := range args[2:] { @@ -152,54 +211,54 @@ func (_c *MockDBTX_Query_Call) Run(run func(_a0 context.Context, _a1 string, _a2 return _c } -func (_c *MockDBTX_Query_Call) Return(_a0 pgx.Rows, _a1 error) *MockDBTX_Query_Call { +func (_c *MockDBTX_QueryContext_Call) Return(_a0 *sql.Rows, _a1 error) *MockDBTX_QueryContext_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockDBTX_Query_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (pgx.Rows, error)) *MockDBTX_Query_Call { +func (_c *MockDBTX_QueryContext_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (*sql.Rows, error)) *MockDBTX_QueryContext_Call { _c.Call.Return(run) return _c } -// QueryRow provides a mock function with given fields: _a0, _a1, _a2 -func (_m *MockDBTX) QueryRow(_a0 context.Context, _a1 string, _a2 ...interface{}) pgx.Row { +// QueryRowContext provides a mock function with given fields: _a0, _a1, _a2 +func (_m *MockDBTX) QueryRowContext(_a0 context.Context, _a1 string, _a2 ...interface{}) *sql.Row { var _ca []interface{} _ca = append(_ca, _a0, _a1) _ca = append(_ca, _a2...) ret := _m.Called(_ca...) if len(ret) == 0 { - panic("no return value specified for QueryRow") + panic("no return value specified for QueryRowContext") } - var r0 pgx.Row - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Row); ok { + var r0 *sql.Row + if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) *sql.Row); ok { r0 = rf(_a0, _a1, _a2...) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(pgx.Row) + r0 = ret.Get(0).(*sql.Row) } } return r0 } -// MockDBTX_QueryRow_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryRow' -type MockDBTX_QueryRow_Call struct { +// MockDBTX_QueryRowContext_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryRowContext' +type MockDBTX_QueryRowContext_Call struct { *mock.Call } -// QueryRow is a helper method to define mock.On call +// QueryRowContext is a helper method to define mock.On call // - _a0 context.Context // - _a1 string // - _a2 ...interface{} -func (_e *MockDBTX_Expecter) QueryRow(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_QueryRow_Call { - return &MockDBTX_QueryRow_Call{Call: _e.mock.On("QueryRow", +func (_e *MockDBTX_Expecter) QueryRowContext(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_QueryRowContext_Call { + return &MockDBTX_QueryRowContext_Call{Call: _e.mock.On("QueryRowContext", append([]interface{}{_a0, _a1}, _a2...)...)} } -func (_c *MockDBTX_QueryRow_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_QueryRow_Call { +func (_c *MockDBTX_QueryRowContext_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_QueryRowContext_Call { _c.Call.Run(func(args mock.Arguments) { variadicArgs := make([]interface{}, len(args)-2) for i, a := range args[2:] { @@ -212,12 +271,12 @@ func (_c *MockDBTX_QueryRow_Call) Run(run func(_a0 context.Context, _a1 string, return _c } -func (_c *MockDBTX_QueryRow_Call) Return(_a0 pgx.Row) *MockDBTX_QueryRow_Call { +func (_c *MockDBTX_QueryRowContext_Call) Return(_a0 *sql.Row) *MockDBTX_QueryRowContext_Call { _c.Call.Return(_a0) return _c } -func (_c *MockDBTX_QueryRow_Call) RunAndReturn(run func(context.Context, string, ...interface{}) pgx.Row) *MockDBTX_QueryRow_Call { +func (_c *MockDBTX_QueryRowContext_Call) RunAndReturn(run func(context.Context, string, ...interface{}) *sql.Row) *MockDBTX_QueryRowContext_Call { _c.Call.Return(run) return _c } diff --git a/backend/gen/go/db/dbschemas/postgresql/system.sql.go b/backend/gen/go/db/dbschemas/postgresql/system.sql.go index 44de095769..9b7c6f15e8 100644 --- a/backend/gen/go/db/dbschemas/postgresql/system.sql.go +++ b/backend/gen/go/db/dbschemas/postgresql/system.sql.go @@ -7,8 +7,9 @@ package pg_queries import ( "context" + "database/sql" - "github.com/jackc/pgx/v5/pgtype" + "github.com/lib/pq" ) const getCustomFunctionsBySchemaAndTables = `-- name: GetCustomFunctionsBySchemaAndTables :many @@ -78,7 +79,7 @@ type GetCustomFunctionsBySchemaAndTablesRow struct { } func (q *Queries) GetCustomFunctionsBySchemaAndTables(ctx context.Context, db DBTX, arg *GetCustomFunctionsBySchemaAndTablesParams) ([]*GetCustomFunctionsBySchemaAndTablesRow, error) { - rows, err := db.Query(ctx, getCustomFunctionsBySchemaAndTables, arg.Schema, arg.Tables) + rows, err := db.QueryContext(ctx, getCustomFunctionsBySchemaAndTables, arg.Schema, pq.Array(arg.Tables)) if err != nil { return nil, err } @@ -96,6 +97,9 @@ func (q *Queries) GetCustomFunctionsBySchemaAndTables(ctx context.Context, db DB } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -174,7 +178,7 @@ type GetCustomSequencesBySchemaAndTablesRow struct { } func (q *Queries) GetCustomSequencesBySchemaAndTables(ctx context.Context, db DBTX, arg *GetCustomSequencesBySchemaAndTablesParams) ([]*GetCustomSequencesBySchemaAndTablesRow, error) { - rows, err := db.Query(ctx, getCustomSequencesBySchemaAndTables, arg.Schema, arg.Tables) + rows, err := db.QueryContext(ctx, getCustomSequencesBySchemaAndTables, arg.Schema, pq.Array(arg.Tables)) if err != nil { return nil, err } @@ -194,6 +198,9 @@ func (q *Queries) GetCustomSequencesBySchemaAndTables(ctx context.Context, db DB } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -225,7 +232,7 @@ type GetCustomTriggersBySchemaAndTablesRow struct { } func (q *Queries) GetCustomTriggersBySchemaAndTables(ctx context.Context, db DBTX, schematables []string) ([]*GetCustomTriggersBySchemaAndTablesRow, error) { - rows, err := db.Query(ctx, getCustomTriggersBySchemaAndTables, schematables) + rows, err := db.QueryContext(ctx, getCustomTriggersBySchemaAndTables, pq.Array(schematables)) if err != nil { return nil, err } @@ -243,6 +250,9 @@ func (q *Queries) GetCustomTriggersBySchemaAndTables(ctx context.Context, db DBT } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -393,7 +403,7 @@ type GetDataTypesBySchemaAndTablesRow struct { } func (q *Queries) GetDataTypesBySchemaAndTables(ctx context.Context, db DBTX, arg *GetDataTypesBySchemaAndTablesParams) ([]*GetDataTypesBySchemaAndTablesRow, error) { - rows, err := db.Query(ctx, getDataTypesBySchemaAndTables, arg.Schema, arg.Tables) + rows, err := db.QueryContext(ctx, getDataTypesBySchemaAndTables, arg.Schema, pq.Array(arg.Tables)) if err != nil { return nil, err } @@ -411,6 +421,9 @@ func (q *Queries) GetDataTypesBySchemaAndTables(ctx context.Context, db DBTX, ar } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -558,18 +571,18 @@ type GetDatabaseSchemaRow struct { OrdinalPosition int16 GeneratedType string IdentityGeneration string - TableOid pgtype.Uint32 + TableOid interface{} SequenceType string - SeqIncrementBy *int64 - SeqMinValue *int64 - SeqMaxValue *int64 - SeqStartValue *int64 - SeqCacheValue *int64 - SeqCycleOption *bool + SeqIncrementBy sql.NullInt64 + SeqMinValue sql.NullInt64 + SeqMaxValue sql.NullInt64 + SeqStartValue sql.NullInt64 + SeqCacheValue sql.NullInt64 + SeqCycleOption sql.NullBool } func (q *Queries) GetDatabaseSchema(ctx context.Context, db DBTX) ([]*GetDatabaseSchemaRow, error) { - rows, err := db.Query(ctx, getDatabaseSchema) + rows, err := db.QueryContext(ctx, getDatabaseSchema) if err != nil { return nil, err } @@ -603,6 +616,9 @@ func (q *Queries) GetDatabaseSchema(ctx context.Context, db DBTX) ([]*GetDatabas } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -750,18 +766,18 @@ type GetDatabaseTableSchemasBySchemasAndTablesRow struct { OrdinalPosition int16 GeneratedType string IdentityGeneration string - TableOid pgtype.Uint32 + TableOid interface{} SequenceType string - SeqIncrementBy *int64 - SeqMinValue *int64 - SeqMaxValue *int64 - SeqStartValue *int64 - SeqCacheValue *int64 - SeqCycleOption *bool + SeqIncrementBy sql.NullInt64 + SeqMinValue sql.NullInt64 + SeqMaxValue sql.NullInt64 + SeqStartValue sql.NullInt64 + SeqCacheValue sql.NullInt64 + SeqCycleOption sql.NullBool } func (q *Queries) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, db DBTX, schematables []string) ([]*GetDatabaseTableSchemasBySchemasAndTablesRow, error) { - rows, err := db.Query(ctx, getDatabaseTableSchemasBySchemasAndTables, schematables) + rows, err := db.QueryContext(ctx, getDatabaseTableSchemasBySchemasAndTables, pq.Array(schematables)) if err != nil { return nil, err } @@ -795,6 +811,9 @@ func (q *Queries) GetDatabaseTableSchemasBySchemasAndTables(ctx context.Context, } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -832,7 +851,7 @@ type GetIndicesBySchemasAndTablesRow struct { } func (q *Queries) GetIndicesBySchemasAndTables(ctx context.Context, db DBTX, schematables []string) ([]*GetIndicesBySchemasAndTablesRow, error) { - rows, err := db.Query(ctx, getIndicesBySchemasAndTables, schematables) + rows, err := db.QueryContext(ctx, getIndicesBySchemasAndTables, pq.Array(schematables)) if err != nil { return nil, err } @@ -850,6 +869,9 @@ func (q *Queries) GetIndicesBySchemasAndTables(ctx context.Context, db DBTX, sch } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -878,7 +900,7 @@ type GetPostgresRolePermissionsRow struct { } func (q *Queries) GetPostgresRolePermissions(ctx context.Context, db DBTX) ([]*GetPostgresRolePermissionsRow, error) { - rows, err := db.Query(ctx, getPostgresRolePermissions) + rows, err := db.QueryContext(ctx, getPostgresRolePermissions) if err != nil { return nil, err } @@ -891,6 +913,9 @@ func (q *Queries) GetPostgresRolePermissions(ctx context.Context, db DBTX) ([]*G } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -979,7 +1004,7 @@ type GetTableConstraintsRow struct { } func (q *Queries) GetTableConstraints(ctx context.Context, db DBTX, arg *GetTableConstraintsParams) ([]*GetTableConstraintsRow, error) { - rows, err := db.Query(ctx, getTableConstraints, arg.Schema, arg.Table) + rows, err := db.QueryContext(ctx, getTableConstraints, arg.Schema, arg.Table) if err != nil { return nil, err } @@ -992,17 +1017,20 @@ func (q *Queries) GetTableConstraints(ctx context.Context, db DBTX, arg *GetTabl &i.ConstraintType, &i.SchemaName, &i.TableName, - &i.ConstraintColumns, - &i.Notnullable, + pq.Array(&i.ConstraintColumns), + pq.Array(&i.Notnullable), &i.ForeignSchemaName, &i.ForeignTableName, - &i.ForeignColumnNames, + pq.Array(&i.ForeignColumnNames), &i.ConstraintDefinition, ); err != nil { return nil, err } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } @@ -1087,7 +1115,7 @@ type GetTableConstraintsBySchemaRow struct { } func (q *Queries) GetTableConstraintsBySchema(ctx context.Context, db DBTX, schema []string) ([]*GetTableConstraintsBySchemaRow, error) { - rows, err := db.Query(ctx, getTableConstraintsBySchema, schema) + rows, err := db.QueryContext(ctx, getTableConstraintsBySchema, pq.Array(schema)) if err != nil { return nil, err } @@ -1100,17 +1128,20 @@ func (q *Queries) GetTableConstraintsBySchema(ctx context.Context, db DBTX, sche &i.ConstraintType, &i.SchemaName, &i.TableName, - &i.ConstraintColumns, - &i.Notnullable, + pq.Array(&i.ConstraintColumns), + pq.Array(&i.Notnullable), &i.ForeignSchemaName, &i.ForeignTableName, - &i.ForeignColumnNames, + pq.Array(&i.ForeignColumnNames), &i.ConstraintDefinition, ); err != nil { return nil, err } items = append(items, &i) } + if err := rows.Close(); err != nil { + return nil, err + } if err := rows.Err(); err != nil { return nil, err } diff --git a/backend/internal/auth/jwt/mock_JwtValidator.go b/backend/internal/auth/jwt/mock_JwtValidator.go index 4e67b36065..c50d8ad93f 100644 --- a/backend/internal/auth/jwt/mock_JwtValidator.go +++ b/backend/internal/auth/jwt/mock_JwtValidator.go @@ -22,23 +22,23 @@ func (_m *MockJwtValidator) EXPECT() *MockJwtValidator_Expecter { } // ValidateToken provides a mock function with given fields: ctx, tokenString -func (_m *MockJwtValidator) ValidateToken(ctx context.Context, tokenString string) (interface{}, error) { +func (_m *MockJwtValidator) ValidateToken(ctx context.Context, tokenString string) (any, error) { ret := _m.Called(ctx, tokenString) if len(ret) == 0 { panic("no return value specified for ValidateToken") } - var r0 interface{} + var r0 any var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (interface{}, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) (any, error)); ok { return rf(ctx, tokenString) } - if rf, ok := ret.Get(0).(func(context.Context, string) interface{}); ok { + if rf, ok := ret.Get(0).(func(context.Context, string) any); ok { r0 = rf(ctx, tokenString) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(interface{}) + r0 = ret.Get(0).(any) } } @@ -70,12 +70,12 @@ func (_c *MockJwtValidator_ValidateToken_Call) Run(run func(ctx context.Context, return _c } -func (_c *MockJwtValidator_ValidateToken_Call) Return(_a0 interface{}, _a1 error) *MockJwtValidator_ValidateToken_Call { +func (_c *MockJwtValidator_ValidateToken_Call) Return(_a0 any, _a1 error) *MockJwtValidator_ValidateToken_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockJwtValidator_ValidateToken_Call) RunAndReturn(run func(context.Context, string) (interface{}, error)) *MockJwtValidator_ValidateToken_Call { +func (_c *MockJwtValidator_ValidateToken_Call) RunAndReturn(run func(context.Context, string) (any, error)) *MockJwtValidator_ValidateToken_Call { _c.Call.Return(run) return _c } diff --git a/backend/internal/neosyncdb/mock_DBTX.go b/backend/internal/neosyncdb/mock_DBTX.go index c5c3dd2d3d..f38e4bab80 100644 --- a/backend/internal/neosyncdb/mock_DBTX.go +++ b/backend/internal/neosyncdb/mock_DBTX.go @@ -201,7 +201,7 @@ func (_c *MockDBTX_CopyFrom_Call) RunAndReturn(run func(context.Context, pgx.Ide } // Exec provides a mock function with given fields: _a0, _a1, _a2 -func (_m *MockDBTX) Exec(_a0 context.Context, _a1 string, _a2 ...interface{}) (pgconn.CommandTag, error) { +func (_m *MockDBTX) Exec(_a0 context.Context, _a1 string, _a2 ...any) (pgconn.CommandTag, error) { var _ca []interface{} _ca = append(_ca, _a0, _a1) _ca = append(_ca, _a2...) @@ -213,16 +213,16 @@ func (_m *MockDBTX) Exec(_a0 context.Context, _a1 string, _a2 ...interface{}) (p var r0 pgconn.CommandTag var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) (pgconn.CommandTag, error)); ok { return rf(_a0, _a1, _a2...) } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgconn.CommandTag); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) pgconn.CommandTag); ok { r0 = rf(_a0, _a1, _a2...) } else { r0 = ret.Get(0).(pgconn.CommandTag) } - if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, ...any) error); ok { r1 = rf(_a0, _a1, _a2...) } else { r1 = ret.Error(1) @@ -239,18 +239,18 @@ type MockDBTX_Exec_Call struct { // Exec is a helper method to define mock.On call // - _a0 context.Context // - _a1 string -// - _a2 ...interface{} +// - _a2 ...any func (_e *MockDBTX_Expecter) Exec(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_Exec_Call { return &MockDBTX_Exec_Call{Call: _e.mock.On("Exec", append([]interface{}{_a0, _a1}, _a2...)...)} } -func (_c *MockDBTX_Exec_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_Exec_Call { +func (_c *MockDBTX_Exec_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...any)) *MockDBTX_Exec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]interface{}, len(args)-2) + variadicArgs := make([]any, len(args)-2) for i, a := range args[2:] { if a != nil { - variadicArgs[i] = a.(interface{}) + variadicArgs[i] = a.(any) } } run(args[0].(context.Context), args[1].(string), variadicArgs...) @@ -263,7 +263,7 @@ func (_c *MockDBTX_Exec_Call) Return(_a0 pgconn.CommandTag, _a1 error) *MockDBTX return _c } -func (_c *MockDBTX_Exec_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)) *MockDBTX_Exec_Call { +func (_c *MockDBTX_Exec_Call) RunAndReturn(run func(context.Context, string, ...any) (pgconn.CommandTag, error)) *MockDBTX_Exec_Call { _c.Call.Return(run) return _c } @@ -315,7 +315,7 @@ func (_c *MockDBTX_Ping_Call) RunAndReturn(run func(context.Context) error) *Moc } // Query provides a mock function with given fields: _a0, _a1, _a2 -func (_m *MockDBTX) Query(_a0 context.Context, _a1 string, _a2 ...interface{}) (pgx.Rows, error) { +func (_m *MockDBTX) Query(_a0 context.Context, _a1 string, _a2 ...any) (pgx.Rows, error) { var _ca []interface{} _ca = append(_ca, _a0, _a1) _ca = append(_ca, _a2...) @@ -327,10 +327,10 @@ func (_m *MockDBTX) Query(_a0 context.Context, _a1 string, _a2 ...interface{}) ( var r0 pgx.Rows var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgx.Rows, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) (pgx.Rows, error)); ok { return rf(_a0, _a1, _a2...) } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Rows); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) pgx.Rows); ok { r0 = rf(_a0, _a1, _a2...) } else { if ret.Get(0) != nil { @@ -338,7 +338,7 @@ func (_m *MockDBTX) Query(_a0 context.Context, _a1 string, _a2 ...interface{}) ( } } - if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, ...any) error); ok { r1 = rf(_a0, _a1, _a2...) } else { r1 = ret.Error(1) @@ -355,18 +355,18 @@ type MockDBTX_Query_Call struct { // Query is a helper method to define mock.On call // - _a0 context.Context // - _a1 string -// - _a2 ...interface{} +// - _a2 ...any func (_e *MockDBTX_Expecter) Query(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_Query_Call { return &MockDBTX_Query_Call{Call: _e.mock.On("Query", append([]interface{}{_a0, _a1}, _a2...)...)} } -func (_c *MockDBTX_Query_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_Query_Call { +func (_c *MockDBTX_Query_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...any)) *MockDBTX_Query_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]interface{}, len(args)-2) + variadicArgs := make([]any, len(args)-2) for i, a := range args[2:] { if a != nil { - variadicArgs[i] = a.(interface{}) + variadicArgs[i] = a.(any) } } run(args[0].(context.Context), args[1].(string), variadicArgs...) @@ -379,13 +379,13 @@ func (_c *MockDBTX_Query_Call) Return(_a0 pgx.Rows, _a1 error) *MockDBTX_Query_C return _c } -func (_c *MockDBTX_Query_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (pgx.Rows, error)) *MockDBTX_Query_Call { +func (_c *MockDBTX_Query_Call) RunAndReturn(run func(context.Context, string, ...any) (pgx.Rows, error)) *MockDBTX_Query_Call { _c.Call.Return(run) return _c } // QueryRow provides a mock function with given fields: _a0, _a1, _a2 -func (_m *MockDBTX) QueryRow(_a0 context.Context, _a1 string, _a2 ...interface{}) pgx.Row { +func (_m *MockDBTX) QueryRow(_a0 context.Context, _a1 string, _a2 ...any) pgx.Row { var _ca []interface{} _ca = append(_ca, _a0, _a1) _ca = append(_ca, _a2...) @@ -396,7 +396,7 @@ func (_m *MockDBTX) QueryRow(_a0 context.Context, _a1 string, _a2 ...interface{} } var r0 pgx.Row - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Row); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) pgx.Row); ok { r0 = rf(_a0, _a1, _a2...) } else { if ret.Get(0) != nil { @@ -415,18 +415,18 @@ type MockDBTX_QueryRow_Call struct { // QueryRow is a helper method to define mock.On call // - _a0 context.Context // - _a1 string -// - _a2 ...interface{} +// - _a2 ...any func (_e *MockDBTX_Expecter) QueryRow(_a0 interface{}, _a1 interface{}, _a2 ...interface{}) *MockDBTX_QueryRow_Call { return &MockDBTX_QueryRow_Call{Call: _e.mock.On("QueryRow", append([]interface{}{_a0, _a1}, _a2...)...)} } -func (_c *MockDBTX_QueryRow_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...interface{})) *MockDBTX_QueryRow_Call { +func (_c *MockDBTX_QueryRow_Call) Run(run func(_a0 context.Context, _a1 string, _a2 ...any)) *MockDBTX_QueryRow_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]interface{}, len(args)-2) + variadicArgs := make([]any, len(args)-2) for i, a := range args[2:] { if a != nil { - variadicArgs[i] = a.(interface{}) + variadicArgs[i] = a.(any) } } run(args[0].(context.Context), args[1].(string), variadicArgs...) @@ -439,7 +439,7 @@ func (_c *MockDBTX_QueryRow_Call) Return(_a0 pgx.Row) *MockDBTX_QueryRow_Call { return _c } -func (_c *MockDBTX_QueryRow_Call) RunAndReturn(run func(context.Context, string, ...interface{}) pgx.Row) *MockDBTX_QueryRow_Call { +func (_c *MockDBTX_QueryRow_Call) RunAndReturn(run func(context.Context, string, ...any) pgx.Row) *MockDBTX_QueryRow_Call { _c.Call.Return(run) return _c } diff --git a/backend/pkg/dbconnect-config/dbconnect-config.go b/backend/pkg/dbconnect-config/dbconnect-config.go index c170bfb6c6..f39c7eee3e 100644 --- a/backend/pkg/dbconnect-config/dbconnect-config.go +++ b/backend/pkg/dbconnect-config/dbconnect-config.go @@ -26,28 +26,6 @@ type GeneralDbConnectConfig struct { queryParams url.Values } -func (g *GeneralDbConnectConfig) GetDriver() string { - return g.driver -} - -func (g *GeneralDbConnectConfig) SetPort(port int32) { - g.port = &port -} -func (g *GeneralDbConnectConfig) SetHost(host string) { - g.host = host -} - -func (g *GeneralDbConnectConfig) GetPort() *int32 { - return g.port -} -func (g *GeneralDbConnectConfig) GetHost() string { - return g.host -} - -func (g *GeneralDbConnectConfig) GetUser() string { - return g.user -} - func (g *GeneralDbConnectConfig) String() string { if g.driver == postgresDriver { u := url.URL{ diff --git a/backend/pkg/dbconnect-config/dbconnect-config_test.go b/backend/pkg/dbconnect-config/dbconnect-config_test.go index c1f05f0ae9..108c9a9025 100644 --- a/backend/pkg/dbconnect-config/dbconnect-config_test.go +++ b/backend/pkg/dbconnect-config/dbconnect-config_test.go @@ -5,30 +5,8 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" ) -func Test_GeneralDbConnectionConfig_Helper_Methods(t *testing.T) { - cfg := GeneralDbConnectConfig{ - driver: "postgres", - host: "localhost", - port: ptr(int32(5432)), - database: ptr("mydb"), - user: "test-user", - pass: "test-pass", - queryParams: url.Values{"sslmode": []string{"verify"}}, - } - require.Equal(t, cfg.GetDriver(), "postgres") - require.Equal(t, cfg.GetHost(), "localhost") - require.Equal(t, *cfg.GetPort(), int32(5432)) - require.Equal(t, cfg.GetUser(), "test-user") - - cfg.SetHost("foo") - cfg.SetPort(5433) - require.Equal(t, cfg.GetHost(), "foo") - require.Equal(t, *cfg.GetPort(), int32(5433)) -} - func Test_GeneralDbConnectionConfig_String(t *testing.T) { type testcase struct { name string diff --git a/backend/pkg/mongoconnect/connector.go b/backend/pkg/mongoconnect/connector.go index 35b8a1f609..e36b6578f4 100644 --- a/backend/pkg/mongoconnect/connector.go +++ b/backend/pkg/mongoconnect/connector.go @@ -10,16 +10,9 @@ import ( mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" "github.com/nucleuscloud/neosync/backend/pkg/clienttls" - "github.com/nucleuscloud/neosync/backend/pkg/sshtunnel" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/x/mongo/driver/connstring" - "golang.org/x/crypto/ssh" -) - -const ( - localhost = "localhost" - randomPort = 0 ) type Interface interface { @@ -36,9 +29,6 @@ type WrappedMongoClient struct { clientMu sync.Mutex details *connstring.ConnString - // tunnel *sshtunnel.Sshtunnel - - // logger *slog.Logger } var _ DbContainer = &WrappedMongoClient{} @@ -91,7 +81,7 @@ func (c *Connector) NewFromConnectionConfig( return nil, errors.New("cc was nil, expected *mgmtv1alpha1.ConnectionConfig") } - details, err := GetConnectionDetails(cc, clienttls.UpsertClientTlsFileSingleClient, logger) + details, err := getConnectionDetails(cc, clienttls.UpsertClientTlsFileSingleClient) if err != nil { return nil, err } @@ -100,22 +90,16 @@ func (c *Connector) NewFromConnectionConfig( } type ConnectionDetails struct { - Tunnel *sshtunnel.Sshtunnel Details *connstring.ConnString } -func (c *ConnectionDetails) GetTunnel() *sshtunnel.Sshtunnel { - return c.Tunnel -} func (c *ConnectionDetails) String() string { - // todo: add tunnel support return c.Details.String() } -func GetConnectionDetails( +func getConnectionDetails( cc *mgmtv1alpha1.ConnectionConfig, handleClientTlsConfig clienttls.ClientTlsFileHandler, - logger *slog.Logger, ) (*ConnectionDetails, error) { if cc == nil { return nil, errors.New("cc was nil, expected *mgmtv1alpha1.ConnectionConfig") @@ -133,44 +117,15 @@ func GetConnectionDetails( } } tunnelCfg := mongoConfig.GetTunnel() - if tunnelCfg == nil { - connDetails, err := getGeneralDbConnectConfigFromMongo(mongoConfig) - if err != nil { - return nil, err - } - return &ConnectionDetails{ - Details: connDetails, - }, nil + if tunnelCfg != nil { + return nil, fmt.Errorf("tunneling in mongodb is not currently supported: %w", errors.ErrUnsupported) } - var destination *sshtunnel.Endpoint // todo - authmethod, err := sshtunnel.GetTunnelAuthMethodFromSshConfig(tunnelCfg.GetAuthentication()) - if err != nil { - return nil, err - } - var publickey ssh.PublicKey - if tunnelCfg.GetKnownHostPublicKey() == "" { - publickey, err = sshtunnel.ParseSshKey(tunnelCfg.GetKnownHostPublicKey()) - if err != nil { - return nil, err - } - } - tunnel := sshtunnel.New( - sshtunnel.NewEndpointWithUser(tunnelCfg.GetHost(), int(tunnelCfg.GetPort()), tunnelCfg.GetUser()), - authmethod, - destination, - sshtunnel.NewEndpoint(localhost, randomPort), - 1, - publickey, - ) connDetails, err := getGeneralDbConnectConfigFromMongo(mongoConfig) if err != nil { return nil, err } - _ = connDetails - return &ConnectionDetails{ - Tunnel: tunnel, Details: connDetails, }, nil } diff --git a/backend/pkg/sqlconnect/mock_PgPoolContainer.go b/backend/pkg/sqlconnect/mock_PgPoolContainer.go deleted file mode 100644 index ca7dc849c3..0000000000 --- a/backend/pkg/sqlconnect/mock_PgPoolContainer.go +++ /dev/null @@ -1,127 +0,0 @@ -// Code generated by mockery. DO NOT EDIT. - -package sqlconnect - -import ( - context "context" - - pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" - mock "github.com/stretchr/testify/mock" -) - -// MockPgPoolContainer is an autogenerated mock type for the PgPoolContainer type -type MockPgPoolContainer struct { - mock.Mock -} - -type MockPgPoolContainer_Expecter struct { - mock *mock.Mock -} - -func (_m *MockPgPoolContainer) EXPECT() *MockPgPoolContainer_Expecter { - return &MockPgPoolContainer_Expecter{mock: &_m.Mock} -} - -// Close provides a mock function with given fields: -func (_m *MockPgPoolContainer) Close() { - _m.Called() -} - -// MockPgPoolContainer_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' -type MockPgPoolContainer_Close_Call struct { - *mock.Call -} - -// Close is a helper method to define mock.On call -func (_e *MockPgPoolContainer_Expecter) Close() *MockPgPoolContainer_Close_Call { - return &MockPgPoolContainer_Close_Call{Call: _e.mock.On("Close")} -} - -func (_c *MockPgPoolContainer_Close_Call) Run(run func()) *MockPgPoolContainer_Close_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *MockPgPoolContainer_Close_Call) Return() *MockPgPoolContainer_Close_Call { - _c.Call.Return() - return _c -} - -func (_c *MockPgPoolContainer_Close_Call) RunAndReturn(run func()) *MockPgPoolContainer_Close_Call { - _c.Call.Return(run) - return _c -} - -// Open provides a mock function with given fields: _a0 -func (_m *MockPgPoolContainer) Open(_a0 context.Context) (pg_queries.DBTX, error) { - ret := _m.Called(_a0) - - if len(ret) == 0 { - panic("no return value specified for Open") - } - - var r0 pg_queries.DBTX - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (pg_queries.DBTX, error)); ok { - return rf(_a0) - } - if rf, ok := ret.Get(0).(func(context.Context) pg_queries.DBTX); ok { - r0 = rf(_a0) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(pg_queries.DBTX) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(_a0) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockPgPoolContainer_Open_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Open' -type MockPgPoolContainer_Open_Call struct { - *mock.Call -} - -// Open is a helper method to define mock.On call -// - _a0 context.Context -func (_e *MockPgPoolContainer_Expecter) Open(_a0 interface{}) *MockPgPoolContainer_Open_Call { - return &MockPgPoolContainer_Open_Call{Call: _e.mock.On("Open", _a0)} -} - -func (_c *MockPgPoolContainer_Open_Call) Run(run func(_a0 context.Context)) *MockPgPoolContainer_Open_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(context.Context)) - }) - return _c -} - -func (_c *MockPgPoolContainer_Open_Call) Return(_a0 pg_queries.DBTX, _a1 error) *MockPgPoolContainer_Open_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockPgPoolContainer_Open_Call) RunAndReturn(run func(context.Context) (pg_queries.DBTX, error)) *MockPgPoolContainer_Open_Call { - _c.Call.Return(run) - return _c -} - -// NewMockPgPoolContainer creates a new instance of MockPgPoolContainer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockPgPoolContainer(t interface { - mock.TestingT - Cleanup(func()) -}) *MockPgPoolContainer { - mock := &MockPgPoolContainer{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/backend/pkg/sqlconnect/mock_SqlConnector.go b/backend/pkg/sqlconnect/mock_SqlConnector.go index eec7c3ec2e..7c98534a68 100644 --- a/backend/pkg/sqlconnect/mock_SqlConnector.go +++ b/backend/pkg/sqlconnect/mock_SqlConnector.go @@ -82,66 +82,6 @@ func (_c *MockSqlConnector_NewDbFromConnectionConfig_Call) RunAndReturn(run func return _c } -// NewPgPoolFromConnectionConfig provides a mock function with given fields: pgconfig, connectionTimeout, logger -func (_m *MockSqlConnector) NewPgPoolFromConnectionConfig(pgconfig *mgmtv1alpha1.PostgresConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (PgPoolContainer, error) { - ret := _m.Called(pgconfig, connectionTimeout, logger) - - if len(ret) == 0 { - panic("no return value specified for NewPgPoolFromConnectionConfig") - } - - var r0 PgPoolContainer - var r1 error - if rf, ok := ret.Get(0).(func(*mgmtv1alpha1.PostgresConnectionConfig, *uint32, *slog.Logger) (PgPoolContainer, error)); ok { - return rf(pgconfig, connectionTimeout, logger) - } - if rf, ok := ret.Get(0).(func(*mgmtv1alpha1.PostgresConnectionConfig, *uint32, *slog.Logger) PgPoolContainer); ok { - r0 = rf(pgconfig, connectionTimeout, logger) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(PgPoolContainer) - } - } - - if rf, ok := ret.Get(1).(func(*mgmtv1alpha1.PostgresConnectionConfig, *uint32, *slog.Logger) error); ok { - r1 = rf(pgconfig, connectionTimeout, logger) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockSqlConnector_NewPgPoolFromConnectionConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'NewPgPoolFromConnectionConfig' -type MockSqlConnector_NewPgPoolFromConnectionConfig_Call struct { - *mock.Call -} - -// NewPgPoolFromConnectionConfig is a helper method to define mock.On call -// - pgconfig *mgmtv1alpha1.PostgresConnectionConfig -// - connectionTimeout *uint32 -// - logger *slog.Logger -func (_e *MockSqlConnector_Expecter) NewPgPoolFromConnectionConfig(pgconfig interface{}, connectionTimeout interface{}, logger interface{}) *MockSqlConnector_NewPgPoolFromConnectionConfig_Call { - return &MockSqlConnector_NewPgPoolFromConnectionConfig_Call{Call: _e.mock.On("NewPgPoolFromConnectionConfig", pgconfig, connectionTimeout, logger)} -} - -func (_c *MockSqlConnector_NewPgPoolFromConnectionConfig_Call) Run(run func(pgconfig *mgmtv1alpha1.PostgresConnectionConfig, connectionTimeout *uint32, logger *slog.Logger)) *MockSqlConnector_NewPgPoolFromConnectionConfig_Call { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*mgmtv1alpha1.PostgresConnectionConfig), args[1].(*uint32), args[2].(*slog.Logger)) - }) - return _c -} - -func (_c *MockSqlConnector_NewPgPoolFromConnectionConfig_Call) Return(_a0 PgPoolContainer, _a1 error) *MockSqlConnector_NewPgPoolFromConnectionConfig_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockSqlConnector_NewPgPoolFromConnectionConfig_Call) RunAndReturn(run func(*mgmtv1alpha1.PostgresConnectionConfig, *uint32, *slog.Logger) (PgPoolContainer, error)) *MockSqlConnector_NewPgPoolFromConnectionConfig_Call { - _c.Call.Return(run) - return _c -} - // NewMockSqlConnector creates a new instance of MockSqlConnector. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. func NewMockSqlConnector(t interface { diff --git a/backend/pkg/sqlconnect/pgpool.go b/backend/pkg/sqlconnect/pgpool.go deleted file mode 100644 index 9914c55bbd..0000000000 --- a/backend/pkg/sqlconnect/pgpool.go +++ /dev/null @@ -1,123 +0,0 @@ -package sqlconnect - -import ( - context "context" - "fmt" - "log/slog" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" - "github.com/jackc/pgx/v5/tracelog" - pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" - pgxslog "github.com/nucleuscloud/neosync/backend/internal/pgx-slog" - "github.com/nucleuscloud/neosync/backend/pkg/sshtunnel" -) - -// interface used by SqlConnector to abstract away the opening and closing of a Pgxpool that includes tunneling -type PgPoolContainer interface { - Open(context.Context) (pg_queries.DBTX, error) - Close() -} - -type PgPool struct { - pool *pgxpool.Pool - - details *ConnectionDetails - - // instance of the created tunnel - tunnel *sshtunnel.Sshtunnel - - dsn string - - logger *slog.Logger -} - -func newPgPool(details *ConnectionDetails, logger *slog.Logger) *PgPool { - return &PgPool{ - details: details, - logger: logger, - } -} - -func (s *PgPool) GetDsn() string { - return s.dsn -} - -func (s *PgPool) Open(ctx context.Context) (pg_queries.DBTX, error) { - if s.details.Tunnel != nil { - ready, err := s.details.Tunnel.Start(s.logger) - if err != nil { - return nil, err - } - <-ready - - _, localport := s.details.Tunnel.GetLocalHostPort() - newPort := int32(localport) //nolint:gosec // Ignoring for now - s.details.GeneralDbConnectConfig.SetPort(newPort) - dsn := s.details.GeneralDbConnectConfig.String() - - config, err := pgxpool.ParseConfig(dsn) - if err != nil { - return nil, fmt.Errorf("unable to parse dsn into pg config: %w", err) - } - config.ConnConfig.Tracer = &tracelog.TraceLog{ - Logger: pgxslog.NewLogger(s.logger, pgxslog.GetShouldOmitArgs()), - LogLevel: pgxslog.GetDatabaseLogLevel(), - } - config.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeExec - - // set max number of connections. - if s.details.MaxConnectionLimit != nil { - config.MaxConns = *s.details.MaxConnectionLimit - } - - db, err := pgxpool.NewWithConfig(ctx, config) - if err != nil { - s.details.Tunnel.Close() - return nil, err - } - s.dsn = dsn - s.pool = db - s.tunnel = s.details.Tunnel - return db, nil - } - - dsn := s.details.GeneralDbConnectConfig.String() - config, err := pgxpool.ParseConfig(dsn) - if err != nil { - return nil, err - } - config.ConnConfig.Tracer = &tracelog.TraceLog{ - Logger: pgxslog.NewLogger(s.logger, pgxslog.GetShouldOmitArgs()), - LogLevel: pgxslog.GetDatabaseLogLevel(), - } - config.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeExec - - // set max number of connections. - if s.details.MaxConnectionLimit != nil { - config.MaxConns = *s.details.MaxConnectionLimit - } - - db, err := pgxpool.NewWithConfig(ctx, config) - if err != nil { - return nil, err - } - s.pool = db - s.dsn = dsn - return db, nil -} - -func (s *PgPool) Close() { - if s.pool == nil { - return - } - s.dsn = "" - db := s.pool - s.pool = nil - db.Close() - if s.tunnel != nil { - tunnel := s.tunnel - s.tunnel = nil - tunnel.Close() - } -} diff --git a/backend/pkg/sqlconnect/sql-connector.go b/backend/pkg/sqlconnect/sql-connector.go index 6b869cea14..51cbb21fcd 100644 --- a/backend/pkg/sqlconnect/sql-connector.go +++ b/backend/pkg/sqlconnect/sql-connector.go @@ -3,19 +3,30 @@ package sqlconnect import ( "context" "database/sql" + "database/sql/driver" "errors" "fmt" "log/slog" + "sync" + "time" mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" "github.com/nucleuscloud/neosync/backend/pkg/clienttls" dbconnectconfig "github.com/nucleuscloud/neosync/backend/pkg/dbconnect-config" - "github.com/nucleuscloud/neosync/backend/pkg/sshtunnel" + tun "github.com/nucleuscloud/neosync/internal/sshtunnel" + "github.com/nucleuscloud/neosync/internal/sshtunnel/connectors/mssqltunconnector" + "github.com/nucleuscloud/neosync/internal/sshtunnel/connectors/mysqltunconnector" + "github.com/nucleuscloud/neosync/internal/sshtunnel/connectors/postgrestunconnector" "golang.org/x/crypto/ssh" ) +// interface used by SqlConnector to abstract away the opening and closing of a sqldb that includes tunnelingff +type SqlDbContainer interface { + Open() (SqlDBTX, error) + Close() error +} + type SqlDBTX interface { mysql_queries.DBTX @@ -23,305 +34,313 @@ type SqlDBTX interface { BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) } -// Allows instantiating a sql db or pg pool container that includes SSH tunneling if the config requires it type SqlConnector interface { NewDbFromConnectionConfig(connectionConfig *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (SqlDbContainer, error) - NewPgPoolFromConnectionConfig(pgconfig *mgmtv1alpha1.PostgresConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (PgPoolContainer, error) } type SqlOpenConnector struct{} -func (rc *SqlOpenConnector) NewDbFromConnectionConfig(connectionConfig *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (SqlDbContainer, error) { - if connectionConfig == nil { +func (rc *SqlOpenConnector) NewDbFromConnectionConfig(cc *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (SqlDbContainer, error) { + if cc == nil { return nil, errors.New("connectionConfig was nil, expected *mgmtv1alpha1.ConnectionConfig") } - details, err := GetConnectionDetails(connectionConfig, connectionTimeout, clienttls.UpsertCLientTlsFiles, logger) - if err != nil { - return nil, err - } - - return newSqlDb(details, logger), nil -} - -func (rc *SqlOpenConnector) NewPgPoolFromConnectionConfig(pgconfig *mgmtv1alpha1.PostgresConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (PgPoolContainer, error) { - if pgconfig == nil { - return nil, errors.New("pgconfig was nil, expected *mgmtv1alpha1.PostgresConnectionConfig") - } - details, err := GetConnectionDetails(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ - PgConfig: pgconfig, - }, - }, connectionTimeout, clienttls.UpsertCLientTlsFiles, logger) - if err != nil { - return nil, err - } - return newPgPool(details, logger), nil -} - -type ConnectionDetails struct { - dbconnectconfig.GeneralDbConnectConfig - MaxConnectionLimit *int32 - - Tunnel *sshtunnel.Sshtunnel -} + dbconnopts := getConnectionOptsFromConnectionConfig(cc) -func (c *ConnectionDetails) GetTunnel() *sshtunnel.Sshtunnel { - return c.Tunnel -} - -func (c *ConnectionDetails) String() string { - if c.Tunnel != nil { - // todo: would be great to check if tunnel has been started... - localhost, port := c.Tunnel.GetLocalHostPort() - c.GeneralDbConnectConfig.SetHost(localhost) - c.GeneralDbConnectConfig.SetPort(int32(port)) //nolint:gosec // Ignoring for now - } - return c.GeneralDbConnectConfig.String() -} - -type ClientCertConfig struct { - RootCert *string - - ClientCert *string - ClientKey *string -} - -const ( - localhost = "localhost" - randomPort = 0 -) - -// Method for retrieving connection details, including tunneling information. -// Only use if requiring direct access to the SSH Tunnel, otherwise the SqlConnector should be used instead. -func GetConnectionDetails( - c *mgmtv1alpha1.ConnectionConfig, - connectionTimeout *uint32, - handleClientTlsConfig clienttls.ClientTlsFileHandler, - logger *slog.Logger, -) (*ConnectionDetails, error) { - if c == nil { - return nil, errors.New("connection config was nil, expected *mgmtv1alpha1.ConnectionConfig") - } - switch config := c.Config.(type) { + switch config := cc.GetConfig().(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: - var maxConnLimit *int32 - if config.PgConfig.ConnectionOptions != nil { - maxConnLimit = config.PgConfig.ConnectionOptions.MaxConnectionLimit - } if config.PgConfig.GetClientTls() != nil { - _, err := handleClientTlsConfig(config.PgConfig.GetClientTls()) + _, err := clienttls.UpsertCLientTlsFiles(config.PgConfig.GetClientTls()) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to upsert client tls files: %w", err) } } - if config.PgConfig.Tunnel != nil { - destination, err := getEndpointFromPgConnectionConfig(config) - if err != nil { - return nil, err - } - authmethod, err := sshtunnel.GetTunnelAuthMethodFromSshConfig(config.PgConfig.GetTunnel().GetAuthentication()) - if err != nil { - return nil, err - } - var publickey ssh.PublicKey - if config.PgConfig.Tunnel.KnownHostPublicKey != nil { - publickey, err = sshtunnel.ParseSshKey(*config.PgConfig.Tunnel.KnownHostPublicKey) - if err != nil { - return nil, err - } - } - tunnel := sshtunnel.New( - sshtunnel.NewEndpointWithUser(config.PgConfig.Tunnel.GetHost(), int(config.PgConfig.Tunnel.GetPort()), config.PgConfig.Tunnel.GetUser()), - authmethod, - destination, - sshtunnel.NewEndpoint(localhost, randomPort), - 1, - publickey, - ) - connDetails, err := dbconnectconfig.NewFromPostgresConnection(config, connectionTimeout) - if err != nil { - return nil, err - } - portValue := int32(randomPort) - connDetails.SetHost(localhost) - connDetails.SetPort(portValue) - return &ConnectionDetails{ - Tunnel: tunnel, - GeneralDbConnectConfig: *connDetails, - MaxConnectionLimit: maxConnLimit, - }, nil - } - connDetails, err := dbconnectconfig.NewFromPostgresConnection(config, connectionTimeout) if err != nil { return nil, err } - return &ConnectionDetails{ - GeneralDbConnectConfig: *connDetails, - MaxConnectionLimit: maxConnLimit, - }, nil - case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: - var maxConnLimit *int32 - if config.MysqlConfig.ConnectionOptions != nil { - maxConnLimit = config.MysqlConfig.ConnectionOptions.MaxConnectionLimit - } - if config.MysqlConfig.Tunnel != nil { - destination, err := getEndpointFromMysqlConnectionConfig(config) - if err != nil { - return nil, err - } - authmethod, err := sshtunnel.GetTunnelAuthMethodFromSshConfig(config.MysqlConfig.Tunnel.Authentication) - if err != nil { - return nil, err - } - var publickey ssh.PublicKey - if config.MysqlConfig.Tunnel.KnownHostPublicKey != nil { - publickey, err = sshtunnel.ParseSshKey(*config.MysqlConfig.Tunnel.KnownHostPublicKey) - if err != nil { - return nil, err - } - } - tunnel := sshtunnel.New( - sshtunnel.NewEndpointWithUser(config.MysqlConfig.Tunnel.GetHost(), int(config.MysqlConfig.Tunnel.GetPort()), config.MysqlConfig.Tunnel.GetUser()), - authmethod, - destination, - sshtunnel.NewEndpoint(localhost, randomPort), - 1, - publickey, - ) - - connDetails, err := dbconnectconfig.NewFromMysqlConnection(config, connectionTimeout) - if err != nil { - return nil, err - } + dsn := connDetails.String() - portValue := int32(randomPort) - connDetails.SetHost(localhost) - connDetails.SetPort(portValue) - return &ConnectionDetails{ - Tunnel: tunnel, - GeneralDbConnectConfig: *connDetails, - MaxConnectionLimit: maxConnLimit, - }, nil + if config.PgConfig.GetTunnel() != nil { + return newStdlibConnectorContainer( + getTunnelConnectorFn( + config.PgConfig.GetTunnel(), + func(dialer tun.Dialer) (driver.Connector, func(), error) { + return postgrestunconnector.New(dialer, dsn) + }, + logger, + ), + dbconnopts, + ), nil + } else { + return newStdlibContainer("pgx", dsn, dbconnopts), nil } - + case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: connDetails, err := dbconnectconfig.NewFromMysqlConnection(config, connectionTimeout) if err != nil { return nil, err } - return &ConnectionDetails{ - GeneralDbConnectConfig: *connDetails, - MaxConnectionLimit: maxConnLimit, - }, nil - case *mgmtv1alpha1.ConnectionConfig_MssqlConfig: - var maxConnLimit *int32 - if config.MssqlConfig.GetConnectionOptions() != nil { - maxConnLimit = config.MssqlConfig.GetConnectionOptions().MaxConnectionLimit - } - if config.MssqlConfig.GetTunnel() != nil { - destination, err := getEndpointFromMssqlConnectionConfig(config) - if err != nil { - return nil, fmt.Errorf("unable to retrieve tunnel endpoint for mssql: %w", err) - } - authmethod, err := sshtunnel.GetTunnelAuthMethodFromSshConfig(config.MssqlConfig.GetTunnel().GetAuthentication()) - if err != nil { - return nil, fmt.Errorf("unable to compile auth method for ssh tunneling for mssql: %w", err) - } - var publickey ssh.PublicKey - if config.MssqlConfig.GetTunnel().GetKnownHostPublicKey() != "" { - publickey, err = sshtunnel.ParseSshKey(config.MssqlConfig.GetTunnel().GetKnownHostPublicKey()) - if err != nil { - return nil, fmt.Errorf("unable to parse provided known host public key for mssql tunnel: %w", err) - } - } - tunnel := sshtunnel.New( - sshtunnel.NewEndpointWithUser(config.MssqlConfig.GetTunnel().GetHost(), int(config.MssqlConfig.GetTunnel().GetPort()), config.MssqlConfig.GetTunnel().GetUser()), - authmethod, - destination, - sshtunnel.NewEndpoint(localhost, randomPort), - 1, - publickey, - ) - - connDetails, err := dbconnectconfig.NewFromMssqlConnection(config, connectionTimeout) - if err != nil { - return nil, fmt.Errorf("unable to compile connection details for mssql tunnel connection: %w", err) - } + dsn := connDetails.String() - portValue := int32(randomPort) - connDetails.SetHost(localhost) - connDetails.SetPort(portValue) - return &ConnectionDetails{ - Tunnel: tunnel, - GeneralDbConnectConfig: *connDetails, - MaxConnectionLimit: maxConnLimit, - }, nil + if config.MysqlConfig.GetTunnel() != nil { + return newStdlibConnectorContainer( + getTunnelConnectorFn( + config.MysqlConfig.GetTunnel(), + func(dialer tun.Dialer) (driver.Connector, func(), error) { + return mysqltunconnector.New(dialer, dsn) + }, + logger, + ), + dbconnopts, + ), nil } + return newStdlibContainer("mysql", dsn, dbconnopts), nil + case *mgmtv1alpha1.ConnectionConfig_MssqlConfig: connDetails, err := dbconnectconfig.NewFromMssqlConnection(config, connectionTimeout) if err != nil { - return nil, fmt.Errorf("unable to compile connection details for mssql connection: %w", err) + return nil, err + } + dsn := connDetails.String() + + if config.MssqlConfig.GetTunnel() != nil { + return newStdlibConnectorContainer( + getTunnelConnectorFn( + config.MssqlConfig.GetTunnel(), + func(dialer tun.Dialer) (driver.Connector, func(), error) { + return mssqltunconnector.New(dialer, dsn) + }, + logger, + ), + dbconnopts, + ), nil } - return &ConnectionDetails{ - GeneralDbConnectConfig: *connDetails, - MaxConnectionLimit: maxConnLimit, - }, nil + return newStdlibContainer("sqlserver", dsn, dbconnopts), nil default: - return nil, nucleuserrors.NewNotImplemented(fmt.Sprintf("this connection config (%T) is not currently supported", config)) + return nil, fmt.Errorf("unsupported connection: %T", config) } } -func getEndpointFromPgConnectionConfig(config *mgmtv1alpha1.ConnectionConfig_PgConfig) (*sshtunnel.Endpoint, error) { - switch cc := config.PgConfig.ConnectionConfig.(type) { - case *mgmtv1alpha1.PostgresConnectionConfig_Connection: - return sshtunnel.NewEndpointWithUser(cc.Connection.Host, int(cc.Connection.Port), cc.Connection.User), nil - case *mgmtv1alpha1.PostgresConnectionConfig_Url: - details, err := dbconnectconfig.NewFromPostgresConnection(config, nil) +func getTunnelConnectorFn( + tunnel *mgmtv1alpha1.SSHTunnel, + getConnector func(dialer tun.Dialer) (driver.Connector, func(), error), + logger *slog.Logger, +) func() (driver.Connector, func(), error) { + return func() (driver.Connector, func(), error) { + cfg, err := getTunnelConfig(tunnel) if err != nil { - return nil, err + return nil, nil, fmt.Errorf("unable to construct ssh tunnel config: %w", err) + } + logger.Debug("constructed tunnel config") + dialer := tun.NewLazySSHDialer(cfg.Addr, cfg.ClientConfig) + conn, cleanup, err := getConnector(dialer) + if err != nil { + return nil, nil, fmt.Errorf("unable to build db connector: %w", err) } - port := 0 - if details.GetPort() != nil { - port = int(*details.GetPort()) + logger.Debug("built database connector with ssh dialer") + wrappedCleanup := func() { + logger.Debug("cleaning up tunnel connector") + cleanup() + logger.Debug("connector cleanup completed") + if err := dialer.Close(); err != nil { + logger.Error(fmt.Errorf("encountered error when closing ssh dialer: %w", err).Error()) + } + logger.Debug("tunnel connector cleanup completed") } - return sshtunnel.NewEndpointWithUser(details.GetHost(), port, details.GetUser()), nil + return conn, wrappedCleanup, nil + } +} + +func getConnectionOptsFromConnectionConfig(cc *mgmtv1alpha1.ConnectionConfig) *DbConnectionOptions { + switch config := cc.GetConfig().(type) { + case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: + return sqlConnOptsToDbConnOpts(config.MysqlConfig.GetConnectionOptions()) + case *mgmtv1alpha1.ConnectionConfig_PgConfig: + return sqlConnOptsToDbConnOpts(config.PgConfig.GetConnectionOptions()) + case *mgmtv1alpha1.ConnectionConfig_MssqlConfig: + return sqlConnOptsToDbConnOpts(config.MssqlConfig.GetConnectionOptions()) default: - return nil, nucleuserrors.NewBadRequest("must provide valid postgres connection") + return sqlConnOptsToDbConnOpts(&mgmtv1alpha1.SqlConnectionOptions{}) + } +} + +func sqlConnOptsToDbConnOpts(co *mgmtv1alpha1.SqlConnectionOptions) *DbConnectionOptions { + if co == nil { + co = &mgmtv1alpha1.SqlConnectionOptions{} } + return &DbConnectionOptions{ + MaxOpenConns: convertInt32PtrToIntPtr(co.MaxConnectionLimit), + } +} + +func convertInt32PtrToIntPtr(input *int32) *int { + if input == nil { + return nil + } + value := int(*input) + return &value +} + +type tunnelConfig struct { + Addr string + ClientConfig *ssh.ClientConfig } -func getEndpointFromMysqlConnectionConfig(config *mgmtv1alpha1.ConnectionConfig_MysqlConfig) (*sshtunnel.Endpoint, error) { - switch cc := config.MysqlConfig.ConnectionConfig.(type) { - case *mgmtv1alpha1.MysqlConnectionConfig_Connection: - return sshtunnel.NewEndpointWithUser(cc.Connection.Host, int(cc.Connection.Port), cc.Connection.User), nil - case *mgmtv1alpha1.MysqlConnectionConfig_Url: - details, err := dbconnectconfig.NewFromMysqlConnection(config, nil) +func getTunnelConfig(tunnel *mgmtv1alpha1.SSHTunnel) (*tunnelConfig, error) { + var hostcallback ssh.HostKeyCallback + if tunnel.GetKnownHostPublicKey() != "" { + publickey, err := tun.ParseSshKey(tunnel.GetKnownHostPublicKey()) if err != nil { - return nil, err - } - port := 0 - if details.GetPort() != nil { - port = int(*details.GetPort()) + return nil, fmt.Errorf("unable to parse ssh known host public key: %w", err) } - return sshtunnel.NewEndpointWithUser(details.GetHost(), port, details.GetUser()), nil - default: - return nil, nucleuserrors.NewBadRequest("must provide valid mysql connection") + hostcallback = ssh.FixedHostKey(publickey) + } else { + hostcallback = ssh.InsecureIgnoreHostKey() //nolint:gosec // the user has chosen to not provide a known host public key } + authmethod, err := tun.GetTunnelAuthMethodFromSshConfig(tunnel.GetAuthentication()) + if err != nil { + return nil, fmt.Errorf("unable to parse ssh auth method: %w", err) + } + + authmethods := []ssh.AuthMethod{} + if authmethod != nil { + authmethods = append(authmethods, authmethod) + } + + return &tunnelConfig{ + Addr: getSshAddr(tunnel), + ClientConfig: &ssh.ClientConfig{ + User: tunnel.GetUser(), + Auth: authmethods, + HostKeyCallback: hostcallback, + Timeout: 10 * time.Second, // todo: make configurable + }, + }, nil } -func getEndpointFromMssqlConnectionConfig(config *mgmtv1alpha1.ConnectionConfig_MssqlConfig) (*sshtunnel.Endpoint, error) { - switch cc := config.MssqlConfig.GetConnectionConfig().(type) { - case *mgmtv1alpha1.MssqlConnectionConfig_Url: - details, err := dbconnectconfig.NewFromMssqlConnection(config, nil) - if err != nil { - return nil, err +func getSshAddr(tunnel *mgmtv1alpha1.SSHTunnel) string { + host := tunnel.GetHost() + port := tunnel.GetPort() + if port > 0 { + return fmt.Sprintf("%s:%d", host, port) + } + return host +} + +func newStdlibConnectorContainer(getter func() (driver.Connector, func(), error), connopts *DbConnectionOptions) *stdlibConnectorContainer { + return &stdlibConnectorContainer{getter: getter, connopts: connopts} +} + +type stdlibConnectorContainer struct { + db *sql.DB + mu sync.Mutex + cleanup func() + + getter func() (driver.Connector, func(), error) + connopts *DbConnectionOptions +} + +func (s *stdlibConnectorContainer) Open() (SqlDBTX, error) { + s.mu.Lock() + defer s.mu.Unlock() + connector, cleanup, err := s.getter() + if err != nil { + return nil, err + } + s.cleanup = cleanup + db := sql.OpenDB(connector) + setConnectionOpts(db, s.connopts) + s.db = db + return s.db, err +} +func (s *stdlibConnectorContainer) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + db := s.db + cleanup := s.cleanup + if cleanup != nil { + defer cleanup() + } + if db == nil { + return nil + } + s.db = nil + s.cleanup = nil + return db.Close() +} + +type DbConnectionOptions struct { + MaxOpenConns *int + MaxIdleConns *int + + ConnMaxIdleTime *time.Duration + ConnMaxLifetime *time.Duration +} + +func newStdlibContainer(drvr, dsn string, connOpts *DbConnectionOptions) *stdlibContainer { + return &stdlibContainer{driver: drvr, dsn: dsn, connopts: connOpts} +} + +type stdlibContainer struct { + db *sql.DB + mu sync.Mutex + + driver string + dsn string + connopts *DbConnectionOptions +} + +func (s *stdlibContainer) Open() (SqlDBTX, error) { + s.mu.Lock() + defer s.mu.Unlock() + db, err := sql.Open(s.driver, s.dsn) + if err != nil { + return nil, err + } + setConnectionOpts(db, s.connopts) + s.db = db + return db, nil +} + +func setConnectionOpts(db *sql.DB, connopts *DbConnectionOptions) { + if connopts != nil { + if connopts.ConnMaxIdleTime != nil { + db.SetConnMaxIdleTime(*connopts.ConnMaxIdleTime) } - port := 0 - if details.GetPort() != nil { - port = int(*details.GetPort()) + if connopts.ConnMaxLifetime != nil { + db.SetConnMaxLifetime(*connopts.ConnMaxLifetime) + } + if connopts.MaxIdleConns != nil { + db.SetMaxIdleConns(*connopts.MaxIdleConns) + } + if connopts.MaxOpenConns != nil { + db.SetMaxOpenConns(*connopts.MaxOpenConns) } - return sshtunnel.NewEndpointWithUser(details.GetHost(), port, details.GetUser()), nil - default: - return nil, nucleuserrors.NewBadRequest(fmt.Sprintf("must provide valid mssql connection: %T", cc)) } } + +func (s *stdlibContainer) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + db := s.db + if db == nil { + return nil + } + s.db = nil + return db.Close() +} + +type ConnectionDetails struct { + dbconnectconfig.GeneralDbConnectConfig + MaxConnectionLimit *int32 +} + +func (c *ConnectionDetails) String() string { + return c.GeneralDbConnectConfig.String() +} + +type ClientCertConfig struct { + RootCert *string + + ClientCert *string + ClientKey *string +} diff --git a/backend/pkg/sqlconnect/sql-connector_test.go b/backend/pkg/sqlconnect/sql-connector_test.go index 3ea20b1e5a..a0a40e4c3f 100644 --- a/backend/pkg/sqlconnect/sql-connector_test.go +++ b/backend/pkg/sqlconnect/sql-connector_test.go @@ -1,7 +1,6 @@ package sqlconnect import ( - "log/slog" "testing" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" @@ -28,51 +27,55 @@ var ( } mssqlconnection = "sqlserver://sa:YourStrong@Passw0rd@localhost:1433?database=master" -) -func Test_NewDbFromConnectionConfig(t *testing.T) { - c := &SqlOpenConnector{} - sqldb, err := c.NewDbFromConnectionConfig(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ - Connection: mysqlconnection, + tunnel = &mgmtv1alpha1.SSHTunnel{ + Host: "localhost", + Port: 2222, + User: "foo", + KnownHostPublicKey: nil, + Authentication: &mgmtv1alpha1.SSHAuthentication{ + AuthConfig: &mgmtv1alpha1.SSHAuthentication_Passphrase{ + Passphrase: &mgmtv1alpha1.SSHPassphrase{ + Value: "foo", }, }, }, - }, nil, nil) - assert.NoError(t, err) - assert.NotNil(t, sqldb) -} - -func Test_NewDbFromConnectionConfig_BadConfig(t *testing.T) { - c := &SqlOpenConnector{} - sqldb, err := c.NewDbFromConnectionConfig(nil, nil, nil) - assert.Error(t, err) - assert.Nil(t, sqldb) -} + } +) -func Test_NewPgPoolFromConnectionConfig(t *testing.T) { - c := &SqlOpenConnector{} - sqldb, err := c.NewPgPoolFromConnectionConfig(&mgmtv1alpha1.PostgresConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ - Connection: pgconnection, - }, - }, nil, nil) - assert.NoError(t, err) - assert.NotNil(t, sqldb) -} +func Test_NewDbFromConnectionConfig(t *testing.T) { + connector := &SqlOpenConnector{} + t.Run("mysql", func(t *testing.T) { + sqldb, err := connector.NewDbFromConnectionConfig(&mgmtv1alpha1.ConnectionConfig{ + Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ + Connection: mysqlconnection, + }, + }, + }, + }, nil, nil) + assert.NoError(t, err) + assert.NotNil(t, sqldb) + }) -func Test_NewPgPoolFromConnectionConfig_BadConfig(t *testing.T) { - c := &SqlOpenConnector{} - sqldb, err := c.NewPgPoolFromConnectionConfig(nil, nil, nil) - assert.Error(t, err) - assert.Nil(t, sqldb) -} + t.Run("mysql tunnel", func(t *testing.T) { + sqldb, err := connector.NewDbFromConnectionConfig(&mgmtv1alpha1.ConnectionConfig{ + Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ + MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ + ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ + Connection: mysqlconnection, + }, + Tunnel: tunnel, + }, + }, + }, nil, nil) + assert.NoError(t, err) + assert.NotNil(t, sqldb) + }) -func Test_getConnectionDetails_Pg_NoTunnel(t *testing.T) { - out, err := GetConnectionDetails( - &mgmtv1alpha1.ConnectionConfig{ + t.Run("pg", func(t *testing.T) { + sqldb, err := connector.NewDbFromConnectionConfig(&mgmtv1alpha1.ConnectionConfig{ Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ @@ -80,101 +83,28 @@ func Test_getConnectionDetails_Pg_NoTunnel(t *testing.T) { }, }, }, - }, - ptr(uint32(5)), - nil, - slog.Default(), - ) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.NotNil(t, out.GeneralDbConnectConfig) - assert.Nil(t, out.Tunnel) -} + }, nil, nil) + assert.NoError(t, err) + assert.NotNil(t, sqldb) + }) -func Test_getConnectionDetails_Pg_Tunnel(t *testing.T) { - out, err := GetConnectionDetails( - &mgmtv1alpha1.ConnectionConfig{ + t.Run("pg tunnel", func(t *testing.T) { + sqldb, err := connector.NewDbFromConnectionConfig(&mgmtv1alpha1.ConnectionConfig{ Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ Connection: pgconnection, }, - Tunnel: &mgmtv1alpha1.SSHTunnel{ - Host: "bastion.neosync.dev", - Port: 22, - User: "testuser", - Authentication: nil, - KnownHostPublicKey: nil, - }, + Tunnel: tunnel, }, }, - }, - ptr(uint32(5)), - nil, - slog.Default(), - ) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.NotNil(t, out.GeneralDbConnectConfig) - assert.NotNil(t, out.Tunnel) - assert.Equal(t, out.GeneralDbConnectConfig.GetHost(), "localhost") - assert.Equal(t, *out.GeneralDbConnectConfig.GetPort(), 0) -} - -func Test_getConnectionDetails_Mysql_NoTunnel(t *testing.T) { - out, err := GetConnectionDetails( - &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ - Connection: mysqlconnection, - }, - }, - }, - }, - ptr(uint32(5)), - nil, - slog.Default(), - ) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.NotNil(t, out.GeneralDbConnectConfig) - assert.Nil(t, out.Tunnel) -} - -func Test_getConnectionDetails_Mysql_Tunnel(t *testing.T) { - out, err := GetConnectionDetails( - &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ - Connection: mysqlconnection, - }, - Tunnel: &mgmtv1alpha1.SSHTunnel{ - Host: "bastion.neosync.dev", - Port: 22, - User: "testuser", - Authentication: nil, - KnownHostPublicKey: nil, - }, - }, - }, - }, - ptr(uint32(5)), - nil, - slog.Default(), - ) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.NotNil(t, out.GeneralDbConnectConfig) - assert.NotNil(t, out.Tunnel) - assert.Equal(t, out.GeneralDbConnectConfig.GetHost(), "localhost") - assert.Equal(t, *out.GeneralDbConnectConfig.GetPort(), 0) -} + }, nil, nil) + assert.NoError(t, err) + assert.NotNil(t, sqldb) + }) -func Test_getConnectionDetails_Mssql_NoTunnel(t *testing.T) { - out, err := GetConnectionDetails( - &mgmtv1alpha1.ConnectionConfig{ + t.Run("mssql", func(t *testing.T) { + sqldb, err := connector.NewDbFromConnectionConfig(&mgmtv1alpha1.ConnectionConfig{ Config: &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ @@ -182,47 +112,49 @@ func Test_getConnectionDetails_Mssql_NoTunnel(t *testing.T) { }, }, }, - }, - ptr(uint32(5)), - nil, - slog.Default(), - ) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.NotNil(t, out.GeneralDbConnectConfig) - assert.Nil(t, out.Tunnel) -} + }, nil, nil) + assert.NoError(t, err) + assert.NotNil(t, sqldb) + }) -func Test_getConnectionDetails_Mssql_Tunnel(t *testing.T) { - out, err := GetConnectionDetails( - &mgmtv1alpha1.ConnectionConfig{ + t.Run("mssql tunnel", func(t *testing.T) { + sqldb, err := connector.NewDbFromConnectionConfig(&mgmtv1alpha1.ConnectionConfig{ Config: &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ ConnectionConfig: &mgmtv1alpha1.MssqlConnectionConfig_Url{ Url: mssqlconnection, }, - Tunnel: &mgmtv1alpha1.SSHTunnel{ - Host: "bastion.neosync.dev", - Port: 22, - User: "testuser", - Authentication: nil, - KnownHostPublicKey: nil, - }, + Tunnel: tunnel, }, }, - }, - ptr(uint32(5)), - nil, - slog.Default(), - ) - assert.NoError(t, err) - assert.NotNil(t, out) - assert.NotNil(t, out.GeneralDbConnectConfig) - assert.NotNil(t, out.Tunnel) - assert.Equal(t, out.GeneralDbConnectConfig.GetHost(), "localhost") - assert.Equal(t, *out.GeneralDbConnectConfig.GetPort(), 0) + }, nil, nil) + assert.NoError(t, err) + assert.NotNil(t, sqldb) + }) + + t.Run("invalid", func(t *testing.T) { + sqldb, err := connector.NewDbFromConnectionConfig(nil, nil, nil) + assert.Error(t, err) + assert.Nil(t, sqldb) + }) } func ptr[T any](val T) *T { return &val } + +func Test_getSshAddr(t *testing.T) { + t.Run("with port", func(t *testing.T) { + actual := getSshAddr(&mgmtv1alpha1.SSHTunnel{ + Host: "localhost", + Port: 2222, + }) + assert.Equal(t, "localhost:2222", actual) + }) + t.Run("without port", func(t *testing.T) { + actual := getSshAddr(&mgmtv1alpha1.SSHTunnel{ + Host: "localhost", + }) + assert.Equal(t, "localhost", actual) + }) +} diff --git a/backend/pkg/sqlconnect/sqldb.go b/backend/pkg/sqlconnect/sqldb.go deleted file mode 100644 index 85b5c65479..0000000000 --- a/backend/pkg/sqlconnect/sqldb.go +++ /dev/null @@ -1,90 +0,0 @@ -package sqlconnect - -import ( - "database/sql" - slog "log/slog" - - "github.com/nucleuscloud/neosync/backend/pkg/sshtunnel" -) - -// interface used by SqlConnector to abstract away the opening and closing of a sqldb that includes tunneling -type SqlDbContainer interface { - Open() (SqlDBTX, error) - Close() error -} - -type SqlDb struct { - db *sql.DB - - details *ConnectionDetails - - tunnel *sshtunnel.Sshtunnel - - dsn string - - logger *slog.Logger -} - -func newSqlDb(details *ConnectionDetails, logger *slog.Logger) *SqlDb { - return &SqlDb{details: details, logger: logger} -} - -func (s *SqlDb) Open() (SqlDBTX, error) { - if s.details.Tunnel != nil { - ready, err := s.details.Tunnel.Start(s.logger) - if err != nil { - return nil, err - } - <-ready - - _, localport := s.details.Tunnel.GetLocalHostPort() - newPort := int32(localport) //nolint:gosec // Ignoring for now - s.details.GeneralDbConnectConfig.SetPort(newPort) - dsn := s.details.GeneralDbConnectConfig.String() - - db, err := sql.Open(s.details.GeneralDbConnectConfig.GetDriver(), dsn) - if err != nil { - s.details.Tunnel.Close() - return nil, err - } - // set max number of connections. - if s.details.MaxConnectionLimit != nil { - db.SetMaxOpenConns(int(*s.details.MaxConnectionLimit)) - } - s.db = db - s.dsn = dsn - s.tunnel = s.details.Tunnel - return db, nil - } - dsn := s.details.GeneralDbConnectConfig.String() - db, err := sql.Open(s.details.GeneralDbConnectConfig.GetDriver(), dsn) - s.db = db - if err != nil { - return nil, err - } - s.dsn = dsn - // set max number of connections. - if s.details.MaxConnectionLimit != nil { - db.SetMaxOpenConns(int(*s.details.MaxConnectionLimit)) - } - return db, nil -} - -func (s *SqlDb) GetDsn() string { - return s.dsn -} - -func (s *SqlDb) Close() error { - if s.db == nil { - return nil - } - db := s.db - s.dsn = "" - s.db = nil - err := db.Close() - if s.tunnel != nil { - s.tunnel.Close() - s.tunnel = nil - } - return err -} diff --git a/backend/pkg/sqlmanager/mssql/integration_test.go b/backend/pkg/sqlmanager/mssql/integration_test.go index ca1e8b0033..d1ade01e95 100644 --- a/backend/pkg/sqlmanager/mssql/integration_test.go +++ b/backend/pkg/sqlmanager/mssql/integration_test.go @@ -13,7 +13,6 @@ import ( "sync" "testing" - _ "github.com/microsoft/go-mssqldb" "golang.org/x/sync/errgroup" mssql_queries "github.com/nucleuscloud/neosync/backend/pkg/mssql-querier" diff --git a/backend/pkg/sqlmanager/mssql/mssql-manager_integration_test.go b/backend/pkg/sqlmanager/mssql/mssql-manager_integration_test.go index d83ca603a5..1df7c59430 100644 --- a/backend/pkg/sqlmanager/mssql/mssql-manager_integration_test.go +++ b/backend/pkg/sqlmanager/mssql/mssql-manager_integration_test.go @@ -5,6 +5,7 @@ import ( "fmt" "testing" + _ "github.com/microsoft/go-mssqldb" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" "github.com/stretchr/testify/require" ) diff --git a/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go b/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go index 0b650e0b60..754f1e0504 100644 --- a/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go +++ b/backend/pkg/sqlmanager/mysql/mysql-manager_integration_test.go @@ -5,6 +5,7 @@ import ( "fmt" "testing" + _ "github.com/go-sql-driver/mysql" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" "github.com/stretchr/testify/require" ) diff --git a/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go b/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go index ae64662ba3..5cc081f01d 100644 --- a/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go +++ b/backend/pkg/sqlmanager/mysql_sql-manager_integration_test.go @@ -15,6 +15,7 @@ import ( "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" "github.com/stretchr/testify/suite" + _ "github.com/go-sql-driver/mysql" "github.com/testcontainers/testcontainers-go" testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" "github.com/testcontainers/testcontainers-go/wait" diff --git a/backend/pkg/sqlmanager/postgres/integration_test.go b/backend/pkg/sqlmanager/postgres/integration_test.go index 08d6edd7b9..2f082cb4c7 100644 --- a/backend/pkg/sqlmanager/postgres/integration_test.go +++ b/backend/pkg/sqlmanager/postgres/integration_test.go @@ -2,13 +2,13 @@ package sqlmanager_postgres import ( "context" + "database/sql" "fmt" "log/slog" "os" "testing" "time" - "github.com/jackc/pgx/v5/pgxpool" pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" "github.com/stretchr/testify/suite" "github.com/testcontainers/testcontainers-go" @@ -19,7 +19,7 @@ import ( type IntegrationTestSuite struct { suite.Suite - pgpool *pgxpool.Pool + db *sql.DB querier pg_queries.Querier setupSql string @@ -69,32 +69,32 @@ func (s *IntegrationTestSuite) SetupSuite() { } s.teardownSql = string(teardownSql) - pool, err := pgxpool.New(s.ctx, connstr) + db, err := sql.Open("pgx", connstr) if err != nil { panic(err) } - s.pgpool = pool + s.db = db s.querier = pg_queries.New() } // Runs before each test func (s *IntegrationTestSuite) SetupTest() { - _, err := s.pgpool.Exec(s.ctx, s.setupSql) + _, err := s.db.ExecContext(s.ctx, s.setupSql) if err != nil { panic(err) } } func (s *IntegrationTestSuite) TearDownTest() { - _, err := s.pgpool.Exec(s.ctx, s.teardownSql) + _, err := s.db.ExecContext(s.ctx, s.teardownSql) if err != nil { panic(err) } } func (s *IntegrationTestSuite) TearDownSuite() { - if s.pgpool != nil { - s.pgpool.Close() + if s.db != nil { + s.db.Close() } if s.pgcontainer != nil { err := s.pgcontainer.Terminate(s.ctx) diff --git a/backend/pkg/sqlmanager/postgres/postgres-manager.go b/backend/pkg/sqlmanager/postgres/postgres-manager.go index fa637fefe7..81202573c8 100644 --- a/backend/pkg/sqlmanager/postgres/postgres-manager.go +++ b/backend/pkg/sqlmanager/postgres/postgres-manager.go @@ -17,16 +17,16 @@ import ( type PostgresManager struct { querier pg_queries.Querier - pool pg_queries.DBTX + db pg_queries.DBTX close func() } -func NewManager(querier pg_queries.Querier, pool pg_queries.DBTX, closer func()) *PostgresManager { - return &PostgresManager{querier: querier, pool: pool, close: closer} +func NewManager(querier pg_queries.Querier, db pg_queries.DBTX, closer func()) *PostgresManager { + return &PostgresManager{querier: querier, db: db, close: closer} } func (p *PostgresManager) GetDatabaseSchema(ctx context.Context) ([]*sqlmanager_shared.DatabaseSchemaRow, error) { - dbSchemas, err := p.querier.GetDatabaseSchema(ctx, p.pool) + dbSchemas, err := p.querier.GetDatabaseSchema(ctx, p.db) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -77,7 +77,7 @@ func (p *PostgresManager) GetTableConstraintsBySchema(ctx context.Context, schem if len(schemas) == 0 { return &sqlmanager_shared.TableConstraints{}, nil } - rows, err := p.querier.GetTableConstraintsBySchema(ctx, p.pool, schemas) + rows, err := p.querier.GetTableConstraintsBySchema(ctx, p.db, schemas) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -124,7 +124,7 @@ func (p *PostgresManager) GetTableConstraintsBySchema(ctx context.Context, schem } func (p *PostgresManager) GetRolePermissionsMap(ctx context.Context) (map[string][]string, error) { - rows, err := p.querier.GetPostgresRolePermissions(ctx, p.pool) + rows, err := p.querier.GetPostgresRolePermissions(ctx, p.db) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -146,7 +146,7 @@ func (p *PostgresManager) GetCreateTableStatement(ctx context.Context, schema, t var tableSchemas []*pg_queries.GetDatabaseTableSchemasBySchemasAndTablesRow errgrp.Go(func() error { - result, err := p.querier.GetDatabaseTableSchemasBySchemasAndTables(errctx, p.pool, []string{schematable.String()}) + result, err := p.querier.GetDatabaseTableSchemasBySchemasAndTables(errctx, p.db, []string{schematable.String()}) if err != nil { return fmt.Errorf("unable to generate database table schema: %w", err) } @@ -155,7 +155,7 @@ func (p *PostgresManager) GetCreateTableStatement(ctx context.Context, schema, t }) var tableConstraints []*pg_queries.GetTableConstraintsRow errgrp.Go(func() error { - result, err := p.querier.GetTableConstraints(errctx, p.pool, &pg_queries.GetTableConstraintsParams{ + result, err := p.querier.GetTableConstraints(errctx, p.db, &pg_queries.GetTableConstraintsParams{ Schema: schema, Table: table, }) @@ -187,7 +187,7 @@ func (p *PostgresManager) GetSchemaTableTriggers(ctx context.Context, tables []* combined = append(combined, t.String()) } - rows, err := p.querier.GetCustomTriggersBySchemaAndTables(ctx, p.pool, combined) + rows, err := p.querier.GetCustomTriggersBySchemaAndTables(ctx, p.db, combined) if err != nil && !neosyncdb.IsNoRows(err) { return nil, err } else if err != nil && neosyncdb.IsNoRows(err) { @@ -268,7 +268,7 @@ func (p *PostgresManager) GetSchemaTableDataTypes(ctx context.Context, tables [] } func (p *PostgresManager) GetSequencesByTables(ctx context.Context, schema string, tables []string) ([]*sqlmanager_shared.DataType, error) { - rows, err := p.querier.GetCustomSequencesBySchemaAndTables(ctx, p.pool, &pg_queries.GetCustomSequencesBySchemaAndTablesParams{ + rows, err := p.querier.GetCustomSequencesBySchemaAndTables(ctx, p.db, &pg_queries.GetCustomSequencesBySchemaAndTablesParams{ Schema: schema, Tables: tables, }) @@ -290,7 +290,7 @@ func (p *PostgresManager) GetSequencesByTables(ctx context.Context, schema strin } func (p *PostgresManager) getFunctionsByTables(ctx context.Context, schema string, tables []string) ([]*sqlmanager_shared.DataType, error) { - rows, err := p.querier.GetCustomFunctionsBySchemaAndTables(ctx, p.pool, &pg_queries.GetCustomFunctionsBySchemaAndTablesParams{ + rows, err := p.querier.GetCustomFunctionsBySchemaAndTables(ctx, p.db, &pg_queries.GetCustomFunctionsBySchemaAndTablesParams{ Schema: schema, Tables: tables, }) @@ -318,7 +318,7 @@ type datatypes struct { } func (p *PostgresManager) getDataTypesByTables(ctx context.Context, schema string, tables []string) (*datatypes, error) { - rows, err := p.querier.GetDataTypesBySchemaAndTables(ctx, p.pool, &pg_queries.GetDataTypesBySchemaAndTablesParams{ + rows, err := p.querier.GetDataTypesBySchemaAndTables(ctx, p.db, &pg_queries.GetDataTypesBySchemaAndTablesParams{ Schema: schema, Tables: tables, }) @@ -368,7 +368,7 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* colDefMap := map[string][]*pg_queries.GetDatabaseTableSchemasBySchemasAndTablesRow{} errgrp.Go(func() error { - columnDefs, err := p.querier.GetDatabaseTableSchemasBySchemasAndTables(errctx, p.pool, combined) + columnDefs, err := p.querier.GetDatabaseTableSchemasBySchemasAndTables(errctx, p.db, combined) if err != nil { return err } @@ -381,7 +381,7 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* constraintmap := map[string][]*pg_queries.GetTableConstraintsBySchemaRow{} errgrp.Go(func() error { - constraints, err := p.querier.GetTableConstraintsBySchema(errctx, p.pool, schemas) // todo: update this to only grab what is necessary instead of entire schema + constraints, err := p.querier.GetTableConstraintsBySchema(errctx, p.db, schemas) // todo: update this to only grab what is necessary instead of entire schema if err != nil { return err } @@ -394,7 +394,7 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* indexmap := map[string][]string{} errgrp.Go(func() error { - idxrecords, err := p.querier.GetIndicesBySchemasAndTables(errctx, p.pool, combined) + idxrecords, err := p.querier.GetIndicesBySchemasAndTables(errctx, p.db, combined) if err != nil { return err } @@ -421,15 +421,15 @@ func (p *PostgresManager) GetTableInitStatements(ctx context.Context, tables []* for _, record := range tableData { record := record var seqConfig *SequenceConfiguration - if record.IdentityGeneration != "" && record.SeqStartValue != nil && record.SeqMinValue != nil && - record.SeqMaxValue != nil && record.SeqIncrementBy != nil && record.SeqCycleOption != nil && record.SeqCacheValue != nil { + if record.IdentityGeneration != "" && record.SeqStartValue.Valid && record.SeqMinValue.Valid && + record.SeqMaxValue.Valid && record.SeqIncrementBy.Valid && record.SeqCycleOption.Valid && record.SeqCacheValue.Valid { seqConfig = &SequenceConfiguration{ - StartValue: *record.SeqStartValue, - MinValue: *record.SeqMinValue, - MaxValue: *record.SeqMaxValue, - IncrementBy: *record.SeqIncrementBy, - CycleOption: *record.SeqCycleOption, - CacheValue: *record.SeqCacheValue, + StartValue: record.SeqStartValue.Int64, + MinValue: record.SeqMinValue.Int64, + MaxValue: record.SeqMaxValue.Int64, + IncrementBy: record.SeqIncrementBy.Int64, + CycleOption: record.SeqCycleOption.Bool, + CacheValue: record.SeqCacheValue.Int64, } } columns = append(columns, buildTableCol(&buildTableColRequest{ @@ -707,15 +707,15 @@ func generateCreateTableStatement( for idx := range tableSchemas { record := tableSchemas[idx] var seqConfig *SequenceConfiguration - if record.IdentityGeneration != "" && record.SeqStartValue != nil && record.SeqMinValue != nil && - record.SeqMaxValue != nil && record.SeqIncrementBy != nil && record.SeqCycleOption != nil && record.SeqCacheValue != nil { + if record.IdentityGeneration != "" && record.SeqStartValue.Valid && record.SeqMinValue.Valid && + record.SeqMaxValue.Valid && record.SeqIncrementBy.Valid && record.SeqCycleOption.Valid && record.SeqCacheValue.Valid { seqConfig = &SequenceConfiguration{ - StartValue: *record.SeqStartValue, - MinValue: *record.SeqMinValue, - MaxValue: *record.SeqMaxValue, - IncrementBy: *record.SeqIncrementBy, - CycleOption: *record.SeqCycleOption, - CacheValue: *record.SeqCacheValue, + StartValue: record.SeqStartValue.Int64, + MinValue: record.SeqMinValue.Int64, + MaxValue: record.SeqMaxValue.Int64, + IncrementBy: record.SeqIncrementBy.Int64, + CycleOption: record.SeqCycleOption.Bool, + CacheValue: record.SeqCacheValue.Int64, } } columns[idx] = buildTableCol(&buildTableColRequest{ @@ -825,7 +825,7 @@ func (p *PostgresManager) BatchExec(ctx context.Context, batchSize int, statemen batchCmd = fmt.Sprintf("%s %s", *opts.Prefix, batchCmd) } - _, err := p.pool.Exec(ctx, batchCmd) + _, err := p.db.ExecContext(ctx, batchCmd) if err != nil { return err } @@ -834,7 +834,7 @@ func (p *PostgresManager) BatchExec(ctx context.Context, batchSize int, statemen } func (p *PostgresManager) Exec(ctx context.Context, statement string) error { - _, err := p.pool.Exec(ctx, statement) + _, err := p.db.ExecContext(ctx, statement) if err != nil { return err } @@ -842,7 +842,7 @@ func (p *PostgresManager) Exec(ctx context.Context, statement string) error { } func (p *PostgresManager) Close() { - if p.pool != nil && p.close != nil { + if p.db != nil && p.close != nil { p.close() } } @@ -865,7 +865,7 @@ func (p *PostgresManager) GetTableRowCount( return 0, err } var count int64 - err = p.pool.QueryRow(ctx, sql).Scan(&count) + err = p.db.QueryRowContext(ctx, sql).Scan(&count) if err != nil { return 0, err } diff --git a/backend/pkg/sqlmanager/postgres/postgres-manager_integration_test.go b/backend/pkg/sqlmanager/postgres/postgres-manager_integration_test.go index 5e89a199d2..32c2756657 100644 --- a/backend/pkg/sqlmanager/postgres/postgres-manager_integration_test.go +++ b/backend/pkg/sqlmanager/postgres/postgres-manager_integration_test.go @@ -7,12 +7,13 @@ import ( "github.com/doug-martin/goqu/v9" _ "github.com/doug-martin/goqu/v9/dialect/postgres" + _ "github.com/jackc/pgx/v5/stdlib" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" "github.com/stretchr/testify/require" ) func (s *IntegrationTestSuite) Test_GetDatabaseSchema() { - manager := PostgresManager{querier: s.querier, pool: s.pgpool} + manager := PostgresManager{querier: s.querier, db: s.db} expectedSubset := []*sqlmanager_shared.DatabaseSchemaRow{ { @@ -35,7 +36,7 @@ func (s *IntegrationTestSuite) Test_GetDatabaseSchema() { } func (s *IntegrationTestSuite) Test_GetDatabaseSchema_With_Identity() { - manager := PostgresManager{querier: s.querier, pool: s.pgpool} + manager := PostgresManager{querier: s.querier, db: s.db} expectedSubset := []*sqlmanager_shared.DatabaseSchemaRow{ { @@ -59,7 +60,7 @@ func (s *IntegrationTestSuite) Test_GetDatabaseSchema_With_Identity() { } func (s *IntegrationTestSuite) Test_GetSchemaColumnMap() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) actual, err := manager.GetSchemaColumnMap(context.Background()) require.NoError(s.T(), err) @@ -74,7 +75,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaColumnMap() { } func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap() { - manager := PostgresManager{querier: s.querier, pool: s.pgpool} + manager := PostgresManager{querier: s.querier, db: s.db} actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{s.schema}) require.NoError(s.T(), err) @@ -130,7 +131,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap() { } func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_BasicCircular() { - manager := PostgresManager{querier: s.querier, pool: s.pgpool} + manager := PostgresManager{querier: s.querier, db: s.db} actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{s.schema}) require.NoError(s.T(), err) @@ -171,7 +172,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_BasicCircular() } func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_Composite() { - manager := PostgresManager{querier: s.querier, pool: s.pgpool} + manager := PostgresManager{querier: s.querier, db: s.db} actual, err := manager.GetTableConstraintsBySchema(s.ctx, []string{s.schema}) require.NoError(s.T(), err) @@ -190,7 +191,7 @@ func (s *IntegrationTestSuite) Test_GetForeignKeyConstraintsMap_Composite() { } func (s *IntegrationTestSuite) Test_GetPrimaryKeyConstraintsMap() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{s.schema}) require.NoError(s.T(), err) @@ -207,7 +208,7 @@ func (s *IntegrationTestSuite) Test_GetPrimaryKeyConstraintsMap() { } func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{s.schema}) require.NoError(s.T(), err) @@ -220,7 +221,7 @@ func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap() { } func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap_Composite() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) actual, err := manager.GetTableConstraintsBySchema(context.Background(), []string{s.schema}) require.NoError(s.T(), err) @@ -235,7 +236,7 @@ func (s *IntegrationTestSuite) Test_GetUniqueConstraintsMap_Composite() { } func (s *IntegrationTestSuite) Test_GetRolePermissionsMap() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) actual, err := manager.GetRolePermissionsMap(context.Background()) require.NoError(s.T(), err) @@ -253,7 +254,7 @@ func (s *IntegrationTestSuite) Test_GetRolePermissionsMap() { } func (s *IntegrationTestSuite) Test_GetCreateTableStatement() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) actual, err := manager.GetCreateTableStatement(context.Background(), s.schema, "users") require.NoError(s.T(), err) @@ -262,7 +263,7 @@ func (s *IntegrationTestSuite) Test_GetCreateTableStatement() { } func (s *IntegrationTestSuite) Test_GetTableInitStatements() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) actual, err := manager.GetTableInitStatements( context.Background(), @@ -274,7 +275,7 @@ func (s *IntegrationTestSuite) Test_GetTableInitStatements() { } func (s *IntegrationTestSuite) Test_Exec() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) sql, _, err := goqu.Dialect("postgres").Select("*").From(goqu.T("users").Schema(s.schema)).ToSQL() require.NoError(s.T(), err) @@ -284,7 +285,7 @@ func (s *IntegrationTestSuite) Test_Exec() { } func (s *IntegrationTestSuite) Test_BatchExec() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) sql, _, err := goqu.Dialect("postgres").Select("*").From(goqu.T("users").Schema(s.schema)).ToSQL() require.NoError(s.T(), err) @@ -295,7 +296,7 @@ func (s *IntegrationTestSuite) Test_BatchExec() { } func (s *IntegrationTestSuite) Test_BatchExec_With_Prefix() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) sql, _, err := goqu.Dialect("postgres").Select("*").From(goqu.T("users").Schema(s.schema)).ToSQL() require.NoError(s.T(), err) sql += ";" @@ -307,7 +308,7 @@ func (s *IntegrationTestSuite) Test_BatchExec_With_Prefix() { } func (s *IntegrationTestSuite) Test_GetSchemaInitStatements() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) statements, err := manager.GetSchemaInitStatements(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: s.schema, Table: "parent1"}}) require.NoError(s.T(), err) @@ -315,7 +316,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaInitStatements() { } func (s *IntegrationTestSuite) Test_GetSchemaInitStatements_customtable() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) statements, err := manager.GetSchemaInitStatements(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: s.schema, Table: "custom_table"}}) require.NoError(s.T(), err) @@ -323,7 +324,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaInitStatements_customtable() { } func (s *IntegrationTestSuite) Test_GetSchemaTableDataTypes_customtable() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) resp, err := manager.GetSchemaTableDataTypes(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: s.schema, Table: "custom_table"}}) require.NoError(s.T(), err) @@ -338,7 +339,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaTableDataTypes_customtable() { } func (s *IntegrationTestSuite) Test_GetSchemaTableTriggers_customtable() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) triggers, err := manager.GetSchemaTableTriggers(context.Background(), []*sqlmanager_shared.SchemaTable{{Schema: s.schema, Table: "custom_table"}}) require.NoError(s.T(), err) @@ -346,7 +347,7 @@ func (s *IntegrationTestSuite) Test_GetSchemaTableTriggers_customtable() { } func (s *IntegrationTestSuite) Test_GetTableRowCount() { - manager := NewManager(s.querier, s.pgpool, func() {}) + manager := NewManager(s.querier, s.db, func() {}) table := "tablewithcount" @@ -378,7 +379,7 @@ type testColumnProperties struct { } func (s *IntegrationTestSuite) Test_GetPostgresColumnOverrideAndResetProperties() { - manager := PostgresManager{querier: s.querier, pool: s.pgpool} + manager := PostgresManager{querier: s.querier, db: s.db} colInfoMap, err := manager.GetSchemaColumnMap(context.Background()) require.NoError(s.T(), err) diff --git a/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go b/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go index 64b7dce55f..04354c32f9 100644 --- a/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go +++ b/backend/pkg/sqlmanager/postgres_sql-manager_integration_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + _ "github.com/jackc/pgx/v5/stdlib" "github.com/testcontainers/testcontainers-go" testpg "github.com/testcontainers/testcontainers-go/modules/postgres" "github.com/testcontainers/testcontainers-go/wait" diff --git a/backend/pkg/sqlmanager/sql-manager.go b/backend/pkg/sqlmanager/sql-manager.go index f3a49121a9..cd26f6bb9b 100644 --- a/backend/pkg/sqlmanager/sql-manager.go +++ b/backend/pkg/sqlmanager/sql-manager.go @@ -9,7 +9,7 @@ import ( "sync" "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgxpool" + "github.com/jackc/pgx/v5/stdlib" "github.com/jackc/pgx/v5/tracelog" mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" @@ -21,10 +21,6 @@ import ( sqlmanager_mysql "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/mysql" sqlmanager_postgres "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/postgres" sqlmanager_shared "github.com/nucleuscloud/neosync/backend/pkg/sqlmanager/shared" - - _ "github.com/go-sql-driver/mysql" - _ "github.com/jackc/pgx/v5/stdlib" - // _ "github.com/microsoft/go-mssqldb" // This is commented out because one of our dependencies is importing this already and it panics if called more than once. ) type SqlDatabase interface { @@ -119,15 +115,11 @@ func (s *SqlManager) NewPooledSqlDb( case *mgmtv1alpha1.ConnectionConfig_PgConfig: var closer func() if _, ok := s.pgpool.Load(connection.Id); !ok { - pgconfig := connection.ConnectionConfig.GetPgConfig() - if pgconfig == nil { - return nil, fmt.Errorf("source connection (%s) is not a postgres config", connection.Id) - } - pgconn, err := s.sqlconnector.NewPgPoolFromConnectionConfig(pgconfig, sqlmanager_shared.Ptr(uint32(5)), slogger) + pgconn, err := s.sqlconnector.NewDbFromConnectionConfig(connection.GetConnectionConfig(), sqlmanager_shared.Ptr(uint32(5)), slogger) if err != nil { return nil, fmt.Errorf("unable to create new postgres pool from connection config: %w", err) } - pool, err := pgconn.Open(ctx) + pool, err := pgconn.Open() if err != nil { return nil, fmt.Errorf("unable to open postgres connection: %w", err) } @@ -245,11 +237,11 @@ func (s *SqlManager) NewSqlDbFromConnectionConfig( if pgconfig == nil { return nil, fmt.Errorf("source connection is not a postgres config") } - pgconn, err := s.sqlconnector.NewPgPoolFromConnectionConfig(pgconfig, connTimeout, slogger) + pgconn, err := s.sqlconnector.NewDbFromConnectionConfig(connectionConfig, connTimeout, slogger) if err != nil { return nil, fmt.Errorf("unable to create new postgres pool from connection config: %w", err) } - pool, err := pgconn.Open(ctx) + pool, err := pgconn.Open() if err != nil { return nil, fmt.Errorf("unable to open postgres connection: %w", err) } @@ -312,22 +304,23 @@ func (s *SqlManager) NewSqlDbFromUrl( var db SqlDatabase switch driver { case sqlmanager_shared.PostgresDriver, "postgres": - pgxconfig, err := pgxpool.ParseConfig(connectionUrl) + pgxconfig, err := pgx.ParseConfig(connectionUrl) if err != nil { return nil, err } - pgxconfig.ConnConfig.Tracer = &tracelog.TraceLog{ + pgxconfig.Tracer = &tracelog.TraceLog{ Logger: pgxslog.NewLogger(slog.Default(), pgxslog.GetShouldOmitArgs()), LogLevel: pgxslog.GetDatabaseLogLevel(), } - pgxconfig.ConnConfig.DefaultQueryExecMode = pgx.QueryExecModeExec - pgconn, err := pgxpool.NewWithConfig(ctx, pgxconfig) + pgxconfig.DefaultQueryExecMode = pgx.QueryExecModeExec + pgconn, err := pgx.ConnectConfig(ctx, pgxconfig) if err != nil { return nil, err } - db = sqlmanager_postgres.NewManager(s.pgquerier, pgconn, func() { + sqldb := stdlib.OpenDB(*pgxconfig) + db = sqlmanager_postgres.NewManager(s.pgquerier, sqldb, func() { if pgconn != nil { - pgconn.Close() + sqldb.Close() } }) driver = sqlmanager_shared.PostgresDriver diff --git a/backend/pkg/sshtunnel/endpoint.go b/backend/pkg/sshtunnel/endpoint.go deleted file mode 100644 index a9c6b7ded9..0000000000 --- a/backend/pkg/sshtunnel/endpoint.go +++ /dev/null @@ -1,32 +0,0 @@ -package sshtunnel - -import ( - "fmt" -) - -type Endpoint struct { - Host string - Port int - User string -} - -// NewEndpoint creates an Endpoint from a string that contains a user, host and -// port. Both User and Port are optional (depending on context). The host can -// be a domain name, IPv4 address or IPv6 address. If it's an IPv6, it must be -// enclosed in square brackets -func NewEndpointWithUser(host string, port int, user string) *Endpoint { - return &Endpoint{ - Host: host, - Port: port, - User: user, - } -} - -func NewEndpoint(host string, port int) *Endpoint { - return &Endpoint{Host: host, Port: port} -} - -// Returns the stringified endpoint sans user -func (endpoint *Endpoint) String() string { - return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) -} diff --git a/backend/pkg/sshtunnel/endpoint_test.go b/backend/pkg/sshtunnel/endpoint_test.go deleted file mode 100644 index e4a1fb3a73..0000000000 --- a/backend/pkg/sshtunnel/endpoint_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package sshtunnel - -import ( - "testing" - - "github.com/zeebo/assert" -) - -func Test_NewEndpointWithUser(t *testing.T) { - assert.Equal( - t, - NewEndpointWithUser("localhost", 5432, "nick"), - &Endpoint{Host: "localhost", Port: 5432, User: "nick"}, - ) -} - -func Test_NewEndpoint(t *testing.T) { - assert.Equal( - t, - NewEndpoint("localhost", 5432), - &Endpoint{Host: "localhost", Port: 5432, User: ""}, - ) -} - -func Test_Endpoint_String(t *testing.T) { - type testcase struct { - name string - input Endpoint - expected string - } - tesstcases := []testcase{ - {name: "empty", input: Endpoint{}, expected: ":0"}, - {name: "host", input: Endpoint{Host: "foo"}, expected: "foo:0"}, - {name: "host+port", input: Endpoint{Host: "foo", Port: 4}, expected: "foo:4"}, - {name: "host+port+user, does not attach username", input: Endpoint{Host: "foo", Port: 4, User: "nick"}, expected: "foo:4"}, - } - for _, tc := range tesstcases { - t.Run(tc.name, func(t *testing.T) { - assert.Equal(t, tc.input.String(), tc.expected) - }) - } -} diff --git a/backend/pkg/sshtunnel/tunnel.go b/backend/pkg/sshtunnel/tunnel.go deleted file mode 100644 index 74fc8be12d..0000000000 --- a/backend/pkg/sshtunnel/tunnel.go +++ /dev/null @@ -1,295 +0,0 @@ -package sshtunnel - -import ( - "errors" - "fmt" - "io" - "log/slog" - "net" - "sync" - "sync/atomic" - "time" - - "github.com/google/uuid" - "golang.org/x/crypto/ssh" -) - -type Sshtunnel struct { - local *Endpoint - server *Endpoint - remote *Endpoint - - config *ssh.ClientConfig - - maxConnectionAttempts uint - close chan any - isOpen atomic.Bool - - shutdowns *sync.Map - - sshclient *ssh.Client - sshMu *sync.RWMutex -} - -func New( - tunnel *Endpoint, - auth ssh.AuthMethod, - destination *Endpoint, - local *Endpoint, - maxConnectionAttempts uint, - serverPublicKey ssh.PublicKey, -) *Sshtunnel { - authmethods := []ssh.AuthMethod{} - if auth != nil { - authmethods = append(authmethods, auth) - } - return &Sshtunnel{ - close: make(chan any), - - local: local, - server: tunnel, - remote: destination, - - maxConnectionAttempts: maxConnectionAttempts, - - config: &ssh.ClientConfig{ - User: tunnel.User, - Auth: authmethods, - HostKeyCallback: getHostKeyCallback(serverPublicKey), - Timeout: 30 * time.Second, - }, - - shutdowns: &sync.Map{}, - - sshMu: &sync.RWMutex{}, - } -} - -// After a tunnel has started, this will return the auto-generated port (if 0 was passed in) -func (t *Sshtunnel) GetLocalHostPort() (host string, port int) { - return t.local.Host, t.local.Port -} - -func getHostKeyCallback(key ssh.PublicKey) ssh.HostKeyCallback { - if key == nil { - return ssh.InsecureIgnoreHostKey() //nolint - } - return ssh.FixedHostKey(key) -} - -func (t *Sshtunnel) Start(logger *slog.Logger) (chan any, error) { - listener, err := net.Listen("tcp", t.local.String()) - if err != nil { - return nil, fmt.Errorf("unable to listen to local endpoint: %w", err) - } - ready := make(chan any) - go t.serve(listener, ready, logger) - return ready, nil -} - -func (t *Sshtunnel) serve(listener net.Listener, ready chan<- any, logger *slog.Logger) { - defer func() { - if err := listener.Close(); err != nil { - if !errors.Is(err, net.ErrClosed) { - logger.Error("failed to close tunnel listener", "error", err) - } - } - }() - - t.local.Port = listener.Addr().(*net.TCPAddr).Port - t.isOpen.Store(true) - close(ready) - - for { - if !t.isOpen.Load() { - break - } - - conn, err := listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - logger.Debug("listener closed, stopping serve loop") - return - } - logger.Error("failed to accept local connection for tunnel", "error", err) - continue - } - - logger.Debug("accepted new connection for tunnel", "remoteAddr", conn.RemoteAddr().String()) - sessionId := uuid.NewString() - shutdown := make(chan any) - t.shutdowns.Store(sessionId, shutdown) - - go func() { - defer func() { - t.shutdowns.Delete(sessionId) - if err := conn.Close(); err != nil { - if !errors.Is(err, net.ErrClosed) { - logger.Error("failed to close tunnel connection for session", "error", err, "sessionId", sessionId) - } - } - }() - select { - case <-t.close: - logger.Debug("received close signal, closing connection", "sessionId", sessionId) - case <-shutdown: - logger.Debug("received shutdown signal for session", "sessionId", sessionId) - default: - t.forward(conn, sessionId, shutdown, logger.With("sessionId", sessionId)) - } - }() - } - - logger.Debug("tunnel closed") -} - -// Takes the local connection, dials into the SSH server, connects to the remote host with that connection, -// and then forwards the traffic from the local connection to the remote connection -func (t *Sshtunnel) forward(localConnection net.Conn, sessionId string, shutdown <-chan any, logger *slog.Logger) { - sshClient, err := t.getSshClient(t.server.String(), t.config, t.maxConnectionAttempts, logger) - if err != nil { - if err := localConnection.Close(); err != nil { - logger.Error(fmt.Sprintf("failed to close local connection: %v", err)) - return - } - logger.Error(fmt.Sprintf("unable to reach SSH server: %v", err)) - return - } - - remoteConnection, err := sshClient.Dial("tcp", t.remote.String()) - if err != nil { - logger.Error(fmt.Sprintf("remote dial error: %s", err)) - if err := sshClient.Close(); err != nil { - logger.Error(fmt.Sprintf("failed to close server connection: %v", err)) - } - if err := localConnection.Close(); err != nil { - logger.Error(fmt.Sprintf("failed to close local connection: %v", err)) - } - return - } - logger.Debug(fmt.Sprintf("connected to %s", t.remote.String())) - - // buffering so that we don't block the copyConnection when it sends its result - done := make(chan error, 2) - go func() { - select { - case <-shutdown: - logger.Debug("issued shutdown of tunnel") - localConnection.Close() - remoteConnection.Close() - t.closeSshClient() - logger.Debug("issued shutdown, closing local, remove, and ssh connections") - case <-done: - t.shutdowns.Delete(sessionId) - localConnection.Close() - remoteConnection.Close() - logger.Debug("connection done, closed local and remote connections") - } - }() - - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - done <- copyConnection(localConnection, remoteConnection, logger.With("direction", "remote->local")) - }() - go func() { - defer wg.Done() - done <- copyConnection(remoteConnection, localConnection, logger.With("direction", "local->remote")) - }() - wg.Wait() - logger.Debug("tunnel forwarding complete for session") -} - -func (t *Sshtunnel) closeSshClient() { - t.sshMu.Lock() - defer t.sshMu.Unlock() - if t.sshclient == nil { - return - } - client := t.sshclient - t.sshclient = nil - client.Close() -} - -func (s *Sshtunnel) getSshClient( - addr string, - config *ssh.ClientConfig, - maxAttempts uint, - logger *slog.Logger, -) (*ssh.Client, error) { - s.sshMu.RLock() - client := s.sshclient - s.sshMu.RUnlock() - if client != nil { - return client, nil - } - s.sshMu.Lock() - defer s.sshMu.Unlock() - if s.sshclient != nil { - return s.sshclient, nil - } - client, err := getSshClient(addr, config, maxAttempts, logger) - if err != nil { - return nil, err - } - logger.Debug(fmt.Sprintf("conntected to %s", addr)) - s.sshclient = client - return client, nil -} - -func getSshClient( - addr string, - config *ssh.ClientConfig, - maxAttempts uint, - logger *slog.Logger, -) (*ssh.Client, error) { - var sshClient *ssh.Client - var err error - var attemptsLeft = maxAttempts - for { - sshClient, err = ssh.Dial("tcp", addr, config) - if err != nil { - attemptsLeft-- - if attemptsLeft <= 0 { - logger.Error(fmt.Sprintf("server dial error: %v: exceeded %d attempts", err, maxAttempts)) - return nil, err - } - logger.Warn(fmt.Sprintf("server dial error: %v: attempt %d/%d", err, maxAttempts-attemptsLeft, maxAttempts)) - } else { - break - } - } - return sshClient, err -} - -// Writer is what receives the input (dst), reader is what the input is read from (src) -func copyConnection(writer, reader net.Conn, logger *slog.Logger) error { - _, err := io.Copy(writer, reader) - if err != nil { - if errors.Is(err, net.ErrClosed) { - // This can be a common error if the thing using the ssh connection was abruptly closed. - // This is common if a user is trying to test their database connection, but they've given Neosync bad credentials - // or something else that causes the server to force close the client connection - logger.Warn("connection was closed before reaching end of input", "error", err) - } else { - logger.Error("unexpected error while streaming through tunnel", "error", err) - } - } else { - logger.Debug("ssh tunnel stream completed successfully") - } - return err -} - -func (t *Sshtunnel) Close() { - if !t.isOpen.CompareAndSwap(true, false) { - return - } - close(t.close) - t.shutdowns.Range(func(key, value any) bool { - if ch, ok := value.(chan any); ok { - close(ch) - } - return true - }) -} diff --git a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go index 372c3b1499..f3f6b3b040 100644 --- a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go +++ b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data.go @@ -18,7 +18,6 @@ import ( "github.com/aws/aws-sdk-go-v2/service/dynamodb" dynamotypes "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" "github.com/aws/aws-sdk-go-v2/service/s3" - "github.com/gofrs/uuid" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" logger_interceptor "github.com/nucleuscloud/neosync/backend/internal/connect/interceptors/logger" nucleuserrors "github.com/nucleuscloud/neosync/backend/internal/errors" @@ -148,11 +147,11 @@ func (s *Service) GetConnectionDataStream( return err } - conn, err := s.sqlConnector.NewPgPoolFromConnectionConfig(config.PgConfig, &connectionTimeout, logger) + conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.GetConnectionConfig(), &connectionTimeout, logger) if err != nil { return err } - db, err := conn.Open(ctx) + db, err := conn.Open() if err != nil { return err } @@ -164,65 +163,41 @@ func (s *Service) GetConnectionDataStream( if err != nil { return err } - r, err := db.Query(ctx, query) + r, err := db.QueryContext(ctx, query) if err != nil && !neosyncdb.IsNoRows(err) { return err } defer r.Close() - columnNames := []string{} - for _, col := range r.FieldDescriptions() { - columnNames = append(columnNames, col.Name) + columnNames, err := r.Columns() + if err != nil { + return err } selectQuery, err := querybuilder.BuildSelectQuery("postgres", table, columnNames, nil) if err != nil { return err } - rows, err := db.Query(ctx, selectQuery) + rows, err := db.QueryContext(ctx, selectQuery) if err != nil && !neosyncdb.IsNoRows(err) { return err } defer rows.Close() + // todo: this is probably way fucking broken now for rows.Next() { values := make([][]byte, len(columnNames)) valuesWrapped := make([]any, 0, len(columnNames)) - - for i, col := range r.FieldDescriptions() { - if col.DataTypeOID == 1082 { // OID for date - var t time.Time - ds := DateScanner{val: &t} - valuesWrapped = append(valuesWrapped, &ds) - } else { - valuesWrapped = append(valuesWrapped, &values[i]) - } + for i := range values { + valuesWrapped = append(valuesWrapped, &values[i]) } - if err := rows.Scan(valuesWrapped...); err != nil { return err } row := map[string][]byte{} for i, v := range values { col := columnNames[i] - if r.FieldDescriptions()[i].DataTypeOID == 1082 { // OID for date - // Convert time.Time value to []byte - if ds, ok := valuesWrapped[i].(*DateScanner); ok && ds.val != nil { - row[col] = []byte(ds.val.Format(time.RFC3339)) - } else { - row[col] = v - } - } else if r.FieldDescriptions()[i].DataTypeOID == 2950 { // OID for UUID - // Convert the byte slice to a uuid.UUID type - uuidValue, err := uuid.FromBytes(v) - if err == nil { - row[col] = []byte(uuidValue.String()) - } else { - row[col] = v - } - } else { - row[col] = v - } + row[col] = v } if err := stream.Send(&mgmtv1alpha1.GetConnectionDataStreamResponse{Row: row}); err != nil { @@ -1049,14 +1024,14 @@ func (s *Service) getConnectionSchema(ctx context.Context, connection *mgmtv1alp func (s *Service) getConnectionTableSchema(ctx context.Context, connection *mgmtv1alpha1.Connection, schema, table string, logger *slog.Logger) ([]*mgmtv1alpha1.DatabaseColumn, error) { conntimeout := uint32(5) - switch cconfig := connection.ConnectionConfig.Config.(type) { + switch connection.GetConnectionConfig().Config.(type) { case *mgmtv1alpha1.ConnectionConfig_PgConfig: - conn, err := s.sqlConnector.NewPgPoolFromConnectionConfig(cconfig.PgConfig, &conntimeout, logger) + conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.GetConnectionConfig(), &conntimeout, logger) if err != nil { return nil, err } defer conn.Close() - db, err := conn.Open(ctx) + db, err := conn.Open() if err != nil { return nil, err } @@ -1077,7 +1052,7 @@ func (s *Service) getConnectionTableSchema(ctx context.Context, connection *mgmt } return schemas, nil case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: - conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.ConnectionConfig, &conntimeout, logger) + conn, err := s.sqlConnector.NewDbFromConnectionConfig(connection.GetConnectionConfig(), &conntimeout, logger) if err != nil { return nil, err } diff --git a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data_test.go b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data_test.go index 4ec4dfd3c9..b30a20364f 100644 --- a/backend/services/mgmt/v1alpha1/connection-data-service/connection-data_test.go +++ b/backend/services/mgmt/v1alpha1/connection-data-service/connection-data_test.go @@ -628,7 +628,6 @@ type serviceMocks struct { SqlMock sqlmock.Sqlmock SqlDbMock *sql.DB SqlDbContainerMock *sqlconnect.MockSqlDbContainer - PgPoolContainerMock *sqlconnect.MockPgPoolContainer PgQueierMock *pg_queries.MockQuerier MysqlQueierMock *mysql_queries.MockQuerier SqlConnectorMock *sqlconnect.MockSqlConnector @@ -671,7 +670,6 @@ func createServiceMock(t *testing.T) *serviceMocks { SqlMock: sqlMock, SqlDbMock: sqlDbMock, SqlDbContainerMock: sqlconnect.NewMockSqlDbContainer(t), - PgPoolContainerMock: sqlconnect.NewMockPgPoolContainer(t), PgQueierMock: mockPgquerier, MysqlQueierMock: mockMysqlquerier, SqlConnectorMock: mockSqlConnector, diff --git a/backend/services/mgmt/v1alpha1/connection-service/connection_test.go b/backend/services/mgmt/v1alpha1/connection-service/connection_test.go deleted file mode 100644 index 850570e3be..0000000000 --- a/backend/services/mgmt/v1alpha1/connection-service/connection_test.go +++ /dev/null @@ -1,855 +0,0 @@ -package v1alpha1_connectionservice - -import ( - "context" - "database/sql" - "errors" - "testing" - "time" - - "connectrpc.com/connect" - "github.com/DATA-DOG/go-sqlmock" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" - db_queries "github.com/nucleuscloud/neosync/backend/gen/go/db" - mysql_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/mysql" - pg_queries "github.com/nucleuscloud/neosync/backend/gen/go/db/dbschemas/postgresql" - mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" - awsmanager "github.com/nucleuscloud/neosync/internal/aws" - - "github.com/nucleuscloud/neosync/backend/internal/apikey" - auth_apikey "github.com/nucleuscloud/neosync/backend/internal/auth/apikey" - "github.com/nucleuscloud/neosync/backend/internal/neosyncdb" - "github.com/nucleuscloud/neosync/backend/pkg/mongoconnect" - mssql_queries "github.com/nucleuscloud/neosync/backend/pkg/mssql-querier" - "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" - pg_models "github.com/nucleuscloud/neosync/backend/sql/postgresql/models" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" -) - -const ( - anonymousUserId = "00000000-0000-0000-0000-000000000000" - mockAuthProvider = "test-provider" - mockUserId = "d5e29f1f-b920-458c-8b86-f3a180e06d98" - mockAccountId = "5629813e-1a35-4874-922c-9827d85f0378" - mockConnectionName = "test-conn" - mockConnectionId = "884765c6-1708-488d-b03a-70a02b12c81e" -) - -type ConnTypeMock string - -const ( - MysqlMock ConnTypeMock = "mysql" - PostgresMock ConnTypeMock = "postgres" -) - -// CheckConnectionConfig -func Test_CheckConnectionConfig_Postgres(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - pool, _ := pgxpool.New(context.Background(), "") - m.PgPoolContainerMock.On("Open", mock.Anything).Return(pool, nil) - m.PgPoolContainerMock.On("Close") - m.SqlConnectorMock.On("NewPgPoolFromConnectionConfig", mock.Anything, mock.Anything, mock.Anything).Return(m.PgPoolContainerMock, nil) - - m.PgQuerierMock.On("GetPostgresRolePermissions", mock.Anything, mock.Anything, mock.Anything). - Return([]*pg_queries.GetPostgresRolePermissionsRow{ - { - TableSchema: "Users", - TableName: "Users", - PrivilegeType: "Insert", - }, - { - TableSchema: "Users", - TableName: "Users", - PrivilegeType: "Delete", - }, - }, nil) - - resp, err := m.Service.CheckConnectionConfig(context.Background(), &connect.Request[mgmtv1alpha1.CheckConnectionConfigRequest]{ - Msg: &mgmtv1alpha1.CheckConnectionConfigRequest{ - ConnectionConfig: getPostgresConfigMock(), - }, - }) - - assert.Nil(t, err) - assert.NotNil(t, resp) - assert.Equal(t, 2, len(resp.Msg.Privileges[0].PrivilegeType), "There should be two privilege types for this connection") -} - -func Test_CheckConnectionConfig_Mysql(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - m.SqlDbContainerMock.On("Open").Return(m.SqlDbMock, nil) - m.SqlDbContainerMock.On("Close").Return(nil) - m.SqlConnectorMock.On("NewDbFromConnectionConfig", mock.Anything, mock.Anything, mock.Anything).Return(m.SqlDbContainerMock, nil) - m.MysqlQuerierMock.On("GetMysqlRolePermissions", mock.Anything, mock.Anything, mock.Anything). - Return([]*mysql_queries.GetMysqlRolePermissionsRow{ - { - TableSchema: "Users", - TableName: "Users", - PrivilegeType: "Insert", - }, - { - TableSchema: "Users", - TableName: "Users", - PrivilegeType: "Delete", - }, - }, nil) - - resp, err := m.Service.CheckConnectionConfig(context.Background(), &connect.Request[mgmtv1alpha1.CheckConnectionConfigRequest]{ - Msg: &mgmtv1alpha1.CheckConnectionConfigRequest{ - ConnectionConfig: getMysqlConfigMock(), - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, true, resp.Msg.IsConnected) - assert.Nil(t, resp.Msg.ConnectionError) - if err := m.SqlMock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_CheckConnectionConfigs_Fail(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - m.SqlDbContainerMock.On("Open").Return(m.SqlDbMock, nil) - m.SqlDbContainerMock.On("Close").Return(nil) - m.SqlConnectorMock.On("NewDbFromConnectionConfig", mock.Anything, mock.Anything, mock.Anything).Return(m.SqlDbContainerMock, nil) - m.MysqlQuerierMock.On("GetMysqlRolePermissions", mock.Anything, mock.Anything, mock.Anything). - Return([]*mysql_queries.GetMysqlRolePermissionsRow{}, errors.New("connection failed")) - - resp, err := m.Service.CheckConnectionConfig(context.Background(), &connect.Request[mgmtv1alpha1.CheckConnectionConfigRequest]{ - Msg: &mgmtv1alpha1.CheckConnectionConfigRequest{ - ConnectionConfig: getMysqlConfigMock(), - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, false, resp.Msg.IsConnected) - assert.NotNil(t, resp.Msg.ConnectionError) - if err := m.SqlMock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_CheckConnectionConfig_NotImplemented(t *testing.T) { - m := createServiceMock(t) - - resp, err := m.Service.CheckConnectionConfig(context.Background(), &connect.Request[mgmtv1alpha1.CheckConnectionConfigRequest]{ - Msg: &mgmtv1alpha1.CheckConnectionConfigRequest{ - ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{}, - }, - }) - - assert.Error(t, err) - assert.Nil(t, resp) -} - -// IsConnectionNameAvailable -func Test_IsConnectionNameAvailable_True(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - accountUuid, _ := neosyncdb.ToUuid(mockAccountId) - mockIsUserInAccount(m.UserAccountServiceMock, true) - m.QuerierMock.On("IsConnectionNameAvailable", context.Background(), mock.Anything, db_queries.IsConnectionNameAvailableParams{ - AccountId: accountUuid, - ConnectionName: mockConnectionName, - }).Return(int64(0), nil) - - resp, err := m.Service.IsConnectionNameAvailable(context.Background(), &connect.Request[mgmtv1alpha1.IsConnectionNameAvailableRequest]{ - Msg: &mgmtv1alpha1.IsConnectionNameAvailableRequest{ - AccountId: mockAccountId, - ConnectionName: mockConnectionName, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, true, resp.Msg.IsAvailable) -} - -func Test_IsConnectionNameAvailable_False(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - accountUuid, _ := neosyncdb.ToUuid(mockAccountId) - mockIsUserInAccount(m.UserAccountServiceMock, true) - m.QuerierMock.On("IsConnectionNameAvailable", context.Background(), mock.Anything, db_queries.IsConnectionNameAvailableParams{ - AccountId: accountUuid, - ConnectionName: mockConnectionName, - }).Return(int64(1), nil) - - resp, err := m.Service.IsConnectionNameAvailable(context.Background(), &connect.Request[mgmtv1alpha1.IsConnectionNameAvailableRequest]{ - Msg: &mgmtv1alpha1.IsConnectionNameAvailableRequest{ - AccountId: mockAccountId, - ConnectionName: mockConnectionName, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, false, resp.Msg.IsAvailable) -} - -// GetConnections -func Test_GetConnections(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - accountUuid, _ := neosyncdb.ToUuid(mockAccountId) - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - connections := []db_queries.NeosyncApiConnection{getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock)} - mockIsUserInAccount(m.UserAccountServiceMock, true) - m.QuerierMock.On("GetConnectionsByAccount", context.Background(), mock.Anything, accountUuid).Return(connections, nil) - - resp, err := m.Service.GetConnections(context.Background(), &connect.Request[mgmtv1alpha1.GetConnectionsRequest]{ - Msg: &mgmtv1alpha1.GetConnectionsRequest{ - AccountId: mockAccountId, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, 1, len(resp.Msg.GetConnections())) - assert.Equal(t, mockConnectionId, resp.Msg.Connections[0].Id) -} - -func Test_GetConnections_Error(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - accountUuid, _ := neosyncdb.ToUuid(mockAccountId) - var nilConnections []db_queries.NeosyncApiConnection - mockIsUserInAccount(m.UserAccountServiceMock, true) - m.QuerierMock.On("GetConnectionsByAccount", context.Background(), mock.Anything, accountUuid).Return(nilConnections, sql.ErrNoRows) - - resp, err := m.Service.GetConnections(context.Background(), &connect.Request[mgmtv1alpha1.GetConnectionsRequest]{ - Msg: &mgmtv1alpha1.GetConnectionsRequest{ - AccountId: mockAccountId, - }, - }) - - assert.Error(t, err) - assert.Nil(t, resp) -} - -// GetConnection -func Test_GetConnection(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - mockIsUserInAccount(m.UserAccountServiceMock, true) - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(connection, nil) - - resp, err := m.Service.GetConnection(context.Background(), &connect.Request[mgmtv1alpha1.GetConnectionRequest]{ - Msg: &mgmtv1alpha1.GetConnectionRequest{ - Id: mockConnectionId, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, mockAccountId, resp.Msg.Connection.AccountId) - assert.Equal(t, mockConnectionId, resp.Msg.Connection.Id) -} - -func Test_GetConnection_Supports_WorkerApiKeys(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - ctx := context.WithValue(context.Background(), auth_apikey.TokenContextKey{}, &auth_apikey.TokenContextData{ - ApiKeyType: apikey.WorkerApiKey, - }) - m.QuerierMock.On("GetConnectionById", ctx, mock.Anything, connectionUuid).Return(connection, nil) - - resp, err := m.Service.GetConnection(ctx, &connect.Request[mgmtv1alpha1.GetConnectionRequest]{ - Msg: &mgmtv1alpha1.GetConnectionRequest{ - Id: mockConnectionId, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, mockAccountId, resp.Msg.Connection.AccountId) - assert.Equal(t, mockConnectionId, resp.Msg.Connection.Id) -} - -func Test_GetConnection_Error(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - var nilConnection db_queries.NeosyncApiConnection - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(nilConnection, sql.ErrNoRows) - - _, err := m.Service.GetConnection(context.Background(), &connect.Request[mgmtv1alpha1.GetConnectionRequest]{ - Msg: &mgmtv1alpha1.GetConnectionRequest{ - Id: mockConnectionId, - }, - }) - - assert.Error(t, err) -} - -// CreateConnection -func Test_CreateConnection(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - accountUuid, _ := neosyncdb.ToUuid(mockAccountId) - userUuid, _ := neosyncdb.ToUuid(mockUserId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - mockMgmtConnConfig := getPostgresConfigMock() - mockConnectionConfig := &pg_models.ConnectionConfig{} - _ = mockConnectionConfig.FromDto(mockMgmtConnConfig) - mockUserAccountCalls(m.UserAccountServiceMock, true) - m.QuerierMock.On("CreateConnection", context.Background(), mock.Anything, db_queries.CreateConnectionParams{ - AccountID: accountUuid, - Name: mockConnectionName, - ConnectionConfig: mockConnectionConfig, - CreatedByID: userUuid, - UpdatedByID: userUuid, - }).Return(connection, nil) - - resp, err := m.Service.CreateConnection(context.Background(), &connect.Request[mgmtv1alpha1.CreateConnectionRequest]{ - Msg: &mgmtv1alpha1.CreateConnectionRequest{ - AccountId: mockAccountId, - Name: mockConnectionName, - ConnectionConfig: mockMgmtConnConfig, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, mockAccountId, resp.Msg.Connection.AccountId) - assert.Equal(t, mockConnectionName, resp.Msg.Connection.Name) -} - -func Test_CreateConnection_Error(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - accountUuid, _ := neosyncdb.ToUuid(mockAccountId) - userUuid, _ := neosyncdb.ToUuid(mockUserId) - mockMgmtConnConfig := getPostgresConfigMock() - mockConnectionConfig := &pg_models.ConnectionConfig{} - _ = mockConnectionConfig.FromDto(mockMgmtConnConfig) - mockIsUserInAccount(m.UserAccountServiceMock, true) - - var nilConnection db_queries.NeosyncApiConnection - m.UserAccountServiceMock.On("GetUser", mock.Anything, mock.Anything).Return(connect.NewResponse(&mgmtv1alpha1.GetUserResponse{ - UserId: mockUserId, - }), nil) - m.QuerierMock.On("CreateConnection", context.Background(), mock.Anything, db_queries.CreateConnectionParams{ - AccountID: accountUuid, - Name: mockConnectionName, - ConnectionConfig: mockConnectionConfig, - CreatedByID: userUuid, - UpdatedByID: userUuid, - }).Return(nilConnection, errors.New("help")) - - resp, err := m.Service.CreateConnection(context.Background(), &connect.Request[mgmtv1alpha1.CreateConnectionRequest]{ - Msg: &mgmtv1alpha1.CreateConnectionRequest{ - AccountId: mockAccountId, - Name: mockConnectionName, - ConnectionConfig: mockMgmtConnConfig, - }, - }) - - assert.Error(t, err) - assert.Nil(t, resp) -} - -// UpdateConnection -func Test_UpdateConnection(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - userUuid, _ := neosyncdb.ToUuid(mockUserId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - mockMgmtConnConfig := getPostgresConfigMock() - mockConnectionConfig := &pg_models.ConnectionConfig{} - _ = mockConnectionConfig.FromDto(mockMgmtConnConfig) - mockUserAccountCalls(m.UserAccountServiceMock, true) - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(connection, nil) - m.QuerierMock.On("UpdateConnection", context.Background(), mock.Anything, db_queries.UpdateConnectionParams{ - ID: connectionUuid, - Name: mockConnectionName, - ConnectionConfig: mockConnectionConfig, - UpdatedByID: userUuid, - }).Return(connection, nil) - - resp, err := m.Service.UpdateConnection(context.Background(), &connect.Request[mgmtv1alpha1.UpdateConnectionRequest]{ - Msg: &mgmtv1alpha1.UpdateConnectionRequest{ - Id: mockConnectionId, - Name: mockConnectionName, - ConnectionConfig: mockMgmtConnConfig, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, mockConnectionId, resp.Msg.Connection.Id) -} - -func Test_UpdateConnection_UpdateError(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - userUuid, _ := neosyncdb.ToUuid(mockUserId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - mockMgmtConnConfig := getPostgresConfigMock() - mockConnectionConfig := &pg_models.ConnectionConfig{} - _ = mockConnectionConfig.FromDto(mockMgmtConnConfig) - mockUserAccountCalls(m.UserAccountServiceMock, true) - var nilConnection db_queries.NeosyncApiConnection - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(connection, nil) - m.QuerierMock.On("UpdateConnection", context.Background(), mock.Anything, db_queries.UpdateConnectionParams{ - ID: connectionUuid, - ConnectionConfig: mockConnectionConfig, - Name: mockConnectionName, - UpdatedByID: userUuid, - }).Return(nilConnection, errors.New("boo")) - - resp, err := m.Service.UpdateConnection(context.Background(), &connect.Request[mgmtv1alpha1.UpdateConnectionRequest]{ - Msg: &mgmtv1alpha1.UpdateConnectionRequest{ - Id: mockConnectionId, - Name: mockConnectionName, - ConnectionConfig: mockMgmtConnConfig, - }, - }) - - assert.Error(t, err) - assert.Nil(t, resp) -} - -func Test_UpdateConnection_GetConnectionError(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - mockMgmtConnConfig := getPostgresConfigMock() - - var nilConnection db_queries.NeosyncApiConnection - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(nilConnection, sql.ErrNoRows) - - resp, err := m.Service.UpdateConnection(context.Background(), &connect.Request[mgmtv1alpha1.UpdateConnectionRequest]{ - Msg: &mgmtv1alpha1.UpdateConnectionRequest{ - Id: mockConnectionId, - Name: mockConnectionName, - ConnectionConfig: mockMgmtConnConfig, - }, - }) - - m.QuerierMock.AssertNotCalled(t, "UpdateConnection", mock.Anything, mock.Anything, mock.Anything) - assert.Error(t, err) - assert.Nil(t, resp) -} - -func Test_UpdateConnection_UnverifiedUser(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - mockMgmtConnConfig := getPostgresConfigMock() - mockConnectionConfig := &pg_models.ConnectionConfig{} - _ = mockConnectionConfig.FromDto(mockMgmtConnConfig) - mockIsUserInAccount(m.UserAccountServiceMock, false) - - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(connection, nil) - - resp, err := m.Service.UpdateConnection(context.Background(), &connect.Request[mgmtv1alpha1.UpdateConnectionRequest]{ - Msg: &mgmtv1alpha1.UpdateConnectionRequest{ - Id: mockConnectionId, - Name: mockConnectionName, - ConnectionConfig: mockMgmtConnConfig, - }, - }) - - m.QuerierMock.AssertNotCalled(t, "UpdateConnection", mock.Anything, mock.Anything, mock.Anything) - assert.Error(t, err) - assert.Nil(t, resp) -} - -// DeleteConnection -func Test_DeleteConnection(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - mockIsUserInAccount(m.UserAccountServiceMock, true) - - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(connection, nil) - m.QuerierMock.On("RemoveConnectionById", context.Background(), mock.Anything, connectionUuid).Return(nil) - - resp, err := m.Service.DeleteConnection(context.Background(), &connect.Request[mgmtv1alpha1.DeleteConnectionRequest]{ - Msg: &mgmtv1alpha1.DeleteConnectionRequest{ - Id: mockConnectionId, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) -} - -func Test_DeleteConnection_NotFound(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - var nilConnection db_queries.NeosyncApiConnection - - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(nilConnection, sql.ErrNoRows) - - resp, err := m.Service.DeleteConnection(context.Background(), &connect.Request[mgmtv1alpha1.DeleteConnectionRequest]{ - Msg: &mgmtv1alpha1.DeleteConnectionRequest{ - Id: mockConnectionId, - }, - }) - - m.QuerierMock.AssertNotCalled(t, "RemoveConnectionById", context.Background(), mock.Anything, mock.Anything) - assert.NoError(t, err) - assert.NotNil(t, resp) -} - -func Test_DeleteConnection_RemoveError(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - mockIsUserInAccount(m.UserAccountServiceMock, true) - - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(connection, nil) - m.QuerierMock.On("RemoveConnectionById", context.Background(), mock.Anything, connectionUuid).Return(errors.New("sad")) - - resp, err := m.Service.DeleteConnection(context.Background(), &connect.Request[mgmtv1alpha1.DeleteConnectionRequest]{ - Msg: &mgmtv1alpha1.DeleteConnectionRequest{ - Id: mockConnectionId, - }, - }) - - assert.Error(t, err) - assert.Nil(t, resp) -} - -func Test_DeleteConnection_UnverifiedUserError(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - connection := getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock) - mockIsUserInAccount(m.UserAccountServiceMock, false) - - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(connection, nil) - - resp, err := m.Service.DeleteConnection(context.Background(), &connect.Request[mgmtv1alpha1.DeleteConnectionRequest]{ - Msg: &mgmtv1alpha1.DeleteConnectionRequest{ - Id: mockConnectionId, - }, - }) - - m.QuerierMock.AssertNotCalled(t, "RemoveConnectionById") - assert.Error(t, err) - assert.Nil(t, resp) -} - -// CheckSqlQuery -func Test_CheckSqlQuery_Valid(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - mockIsUserInAccount(m.UserAccountServiceMock, true) - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock), nil) - - mockQuery := "some query" - m.SqlDbContainerMock.On("Open").Return(m.SqlDbMock, nil) - m.SqlDbContainerMock.On("Close").Return(nil) - m.SqlConnectorMock.On("NewDbFromConnectionConfig", mock.Anything, mock.Anything, mock.Anything).Return(m.SqlDbContainerMock, nil) - m.SqlMock.ExpectBegin() - m.SqlMock.ExpectPrepare(mockQuery) - - resp, err := m.Service.CheckSqlQuery(context.Background(), &connect.Request[mgmtv1alpha1.CheckSqlQueryRequest]{ - Msg: &mgmtv1alpha1.CheckSqlQueryRequest{ - Id: mockConnectionId, - Query: mockQuery, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, true, resp.Msg.IsValid) - assert.Nil(t, resp.Msg.ErorrMessage) - if err := m.SqlMock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_CheckSqlQuery_Invalid(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - mockIsUserInAccount(m.UserAccountServiceMock, true) - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock), nil) - - mockQuery := "another query" - m.SqlDbContainerMock.On("Open").Return(m.SqlDbMock, nil) - m.SqlDbContainerMock.On("Close").Return(nil) - m.SqlConnectorMock.On("NewDbFromConnectionConfig", mock.Anything, mock.Anything, mock.Anything).Return(m.SqlDbContainerMock, nil) - m.SqlMock.ExpectBegin() - m.SqlMock.ExpectPrepare(mockQuery).WillReturnError(errors.New("error")) - m.SqlMock.ExpectRollback() - - resp, err := m.Service.CheckSqlQuery(context.Background(), &connect.Request[mgmtv1alpha1.CheckSqlQueryRequest]{ - Msg: &mgmtv1alpha1.CheckSqlQueryRequest{ - Id: mockConnectionId, - Query: mockQuery, - }, - }) - - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, false, resp.Msg.IsValid) - assert.NotNil(t, resp.Msg.ErorrMessage) - if err := m.SqlMock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -func Test_CheckSqlQuery_Error(t *testing.T) { - m := createServiceMock(t) - defer m.SqlDbMock.Close() - - connectionUuid, _ := neosyncdb.ToUuid(mockConnectionId) - mockIsUserInAccount(m.UserAccountServiceMock, true) - m.QuerierMock.On("GetConnectionById", context.Background(), mock.Anything, connectionUuid).Return(getConnectionMock(mockAccountId, mockConnectionName, connectionUuid, PostgresMock), nil) - - mockQuery := "diff query" - m.SqlDbContainerMock.On("Open").Return(m.SqlDbMock, nil) - m.SqlDbContainerMock.On("Close").Return(nil) - m.SqlConnectorMock.On("NewDbFromConnectionConfig", mock.Anything, mock.Anything, mock.Anything).Return(m.SqlDbContainerMock, nil) - m.SqlMock.ExpectBegin().WillReturnError(errors.New("error")) - - resp, err := m.Service.CheckSqlQuery(context.Background(), &connect.Request[mgmtv1alpha1.CheckSqlQueryRequest]{ - Msg: &mgmtv1alpha1.CheckSqlQueryRequest{ - Id: mockConnectionId, - Query: mockQuery, - }, - }) - - assert.Error(t, err) - assert.Nil(t, resp) - if err := m.SqlMock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - } -} - -type serviceMocks struct { - Service *Service - DbtxMock *neosyncdb.MockDBTX - QuerierMock *db_queries.MockQuerier - UserAccountServiceMock *mgmtv1alpha1connect.MockUserAccountServiceClient - SqlConnectorMock *sqlconnect.MockSqlConnector - SqlMock sqlmock.Sqlmock - SqlDbMock *sql.DB - SqlDbContainerMock *sqlconnect.MockSqlDbContainer - PgPoolContainerMock *sqlconnect.MockPgPoolContainer - PgQuerierMock *pg_queries.MockQuerier - MssqlQuerierMock *mssql_queries.MockQuerier - MysqlQuerierMock *mysql_queries.MockQuerier - MongoConnectorMock *mongoconnect.MockInterface -} - -func createServiceMock(t *testing.T) *serviceMocks { - mockDbtx := neosyncdb.NewMockDBTX(t) - mockQuerier := db_queries.NewMockQuerier(t) - mockUserAccountService := mgmtv1alpha1connect.NewMockUserAccountServiceClient(t) - mockSqlConnector := sqlconnect.NewMockSqlConnector(t) - mockPgquerier := pg_queries.NewMockQuerier(t) - mockMysqlquerier := mysql_queries.NewMockQuerier(t) - mockMssqlQuerier := mssql_queries.NewMockQuerier(t) - mockMongoConnector := mongoconnect.NewMockInterface(t) - mockAwsManager := awsmanager.NewMockNeosyncAwsManagerClient(t) - - sqlDbMock, sqlMock, err := sqlmock.New(sqlmock.MonitorPingsOption(true)) - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - service := New(&Config{}, neosyncdb.New(mockDbtx, mockQuerier), - mockUserAccountService, mockSqlConnector, mockPgquerier, mockMysqlquerier, mockMssqlQuerier, mockMongoConnector, mockAwsManager) - - return &serviceMocks{ - Service: service, - DbtxMock: mockDbtx, - QuerierMock: mockQuerier, - UserAccountServiceMock: mockUserAccountService, - SqlConnectorMock: mockSqlConnector, - SqlMock: sqlMock, - SqlDbMock: sqlDbMock, - SqlDbContainerMock: sqlconnect.NewMockSqlDbContainer(t), - PgPoolContainerMock: sqlconnect.NewMockPgPoolContainer(t), - PgQuerierMock: mockPgquerier, - MysqlQuerierMock: mockMysqlquerier, - MongoConnectorMock: mockMongoConnector, - MssqlQuerierMock: mockMssqlQuerier, - } -} - -func mockIsUserInAccount(userAccountServiceMock *mgmtv1alpha1connect.MockUserAccountServiceClient, isInAccount bool) { - userAccountServiceMock.On("IsUserInAccount", mock.Anything, mock.Anything).Return(connect.NewResponse(&mgmtv1alpha1.IsUserInAccountResponse{ - Ok: isInAccount, - }), nil) -} - -func mockUserAccountCalls(userAccountServiceMock *mgmtv1alpha1connect.MockUserAccountServiceClient, isInAccount bool) { - mockIsUserInAccount(userAccountServiceMock, isInAccount) - userAccountServiceMock.On("GetUser", mock.Anything, mock.Anything).Return(connect.NewResponse(&mgmtv1alpha1.GetUserResponse{ - UserId: mockUserId, - }), nil) -} - -//nolint:all -func getConnectionMock(accountId, name string, id pgtype.UUID, connType ConnTypeMock) db_queries.NeosyncApiConnection { - accountUuid, _ := neosyncdb.ToUuid(accountId) - userUuid, _ := neosyncdb.ToUuid(mockUserId) - timestamp := pgtype.Timestamp{ - Time: time.Now(), - } - if connType == MysqlMock { - return db_queries.NeosyncApiConnection{ - AccountID: accountUuid, - Name: name, - ID: id, - CreatedByID: userUuid, - UpdatedByID: userUuid, - CreatedAt: timestamp, - UpdatedAt: timestamp, - ConnectionConfig: &pg_models.ConnectionConfig{ - MysqlConfig: &pg_models.MysqlConnectionConfig{ - Connection: &pg_models.MysqlConnection{ - Host: "host", - Port: 5432, - Name: "database", - User: "user", - Pass: "topsecret", - Protocol: "tcp", - }, - }, - }, - } - } - sslMode := "disable" - return db_queries.NeosyncApiConnection{ - AccountID: accountUuid, - Name: name, - ID: id, - CreatedByID: userUuid, - UpdatedByID: userUuid, - CreatedAt: timestamp, - UpdatedAt: timestamp, - ConnectionConfig: &pg_models.ConnectionConfig{ - PgConfig: &pg_models.PostgresConnectionConfig{ - Connection: &pg_models.PostgresConnection{ - Host: "host", - Port: 5432, - Name: "database", - User: "user", - Pass: "topsecret", - SslMode: &sslMode, - }, - }, - }, - } -} - -func getPostgresConfigMock() *mgmtv1alpha1.ConnectionConfig { - return &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ - PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ - Connection: getPostgresConnectionMock(), - }, - }, - }, - } -} - -func getMysqlConfigMock() *mgmtv1alpha1.ConnectionConfig { - return &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ - Connection: getMysqlConnectionMock(), - }, - }, - }, - } -} - -func getMysqlConnectionMock() *mgmtv1alpha1.MysqlConnection { - return &mgmtv1alpha1.MysqlConnection{ - Host: "host", - Port: 3306, - Name: "database", - User: "user", - Pass: "topsecret", - Protocol: "tcp", - } -} - -func getPostgresConnectionMock() *mgmtv1alpha1.PostgresConnection { - sslMode := "disable" - return &mgmtv1alpha1.PostgresConnection{ - Host: "host", - Port: 5432, - Name: "database", - User: "user", - Pass: "topsecret", - SslMode: &sslMode, - } -} - -// func getMysqlConfigMock() *mgmtv1alpha1.ConnectionConfig { -// return &mgmtv1alpha1.ConnectionConfig{ -// Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ -// MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ -// ConnectionConfig: &mgmtv1alpha1.MysqlConnectionConfig_Connection{ -// Connection: getMysqlConnectionMock(), -// }, -// }, -// }, -// } -// } - -// func getMysqlConnectionMock() *mgmtv1alpha1.MysqlConnection { -// return &mgmtv1alpha1.MysqlConnection{ -// Host: "host", -// Port: 5432, -// Name: "database", -// User: "user", -// Pass: "topsecret", -// Protocol: "tcp", -// } -// } diff --git a/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go b/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go index 734ec2f111..081339d92f 100644 --- a/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go +++ b/backend/services/mgmt/v1alpha1/integration_tests/connection-service_integration_test.go @@ -1,6 +1,8 @@ package integrationtests_test import ( + "testing" + "connectrpc.com/connect" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" "github.com/stretchr/testify/require" @@ -34,3 +36,150 @@ func (s *IntegrationTestSuite) Test_ConnectionService_IsConnectionNameAvailable_ requireNoErrResp(s.T(), resp, err) require.False(s.T(), resp.Msg.GetIsAvailable()) } + +func (s *IntegrationTestSuite) Test_ConnectionService_CheckConnectionConfig() { + t := s.T() + accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") + require.NoError(t, err) + + conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + + t.Run("valid-pg-connstr", func(t *testing.T) { + t.Parallel() + + resp, err := s.unauthdClients.connections.CheckConnectionConfig( + s.ctx, + connect.NewRequest(&mgmtv1alpha1.CheckConnectionConfigRequest{ + ConnectionConfig: conn.GetConnectionConfig(), + }), + ) + requireNoErrResp(t, resp, err) + require.True(t, resp.Msg.GetIsConnected()) + require.Empty(t, resp.Msg.GetConnectionError()) + }) +} + +func (s *IntegrationTestSuite) Test_ConnectionService_CreateConnection() { + t := s.T() + accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + + t.Run("postgres-success", func(t *testing.T) { + pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") + require.NoError(t, err) + s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + }) +} + +func (s *IntegrationTestSuite) Test_ConnectionService_UpdateConnection() { + t := s.T() + accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + + t.Run("postgres-success", func(t *testing.T) { + pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") + require.NoError(t, err) + conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + + resp, err := s.unauthdClients.connections.UpdateConnection( + s.ctx, + connect.NewRequest(&mgmtv1alpha1.UpdateConnectionRequest{ + Id: conn.GetId(), + Name: "foo2", + ConnectionConfig: conn.GetConnectionConfig(), + }), + ) + requireNoErrResp(t, resp, err) + require.Equal(t, "foo2", resp.Msg.GetConnection().GetName()) + }) +} + +func (s *IntegrationTestSuite) Test_ConnectionService_GetConnection() { + t := s.T() + accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") + require.NoError(t, err) + + conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + + resp, err := s.unauthdClients.connections.GetConnection( + s.ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionRequest{ + Id: conn.GetId(), + }), + ) + requireNoErrResp(t, resp, err) + require.NotNil(t, resp.Msg.GetConnection()) +} + +func (s *IntegrationTestSuite) Test_ConnectionService_GetConnections() { + t := s.T() + accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") + require.NoError(t, err) + + s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + + resp, err := s.unauthdClients.connections.GetConnections( + s.ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionsRequest{ + AccountId: accountId, + }), + ) + requireNoErrResp(t, resp, err) + require.NotEmpty(t, resp.Msg.GetConnections()) +} + +func (s *IntegrationTestSuite) Test_ConnectionService_DeleteConnection() { + t := s.T() + accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") + require.NoError(t, err) + + conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + + resp, err := s.unauthdClients.connections.GetConnections( + s.ctx, + connect.NewRequest(&mgmtv1alpha1.GetConnectionsRequest{ + AccountId: accountId, + }), + ) + requireNoErrResp(t, resp, err) + require.NotEmpty(t, resp.Msg.GetConnections()) + + resp2, err := s.unauthdClients.connections.DeleteConnection( + s.ctx, + connect.NewRequest(&mgmtv1alpha1.DeleteConnectionRequest{ + Id: conn.GetId(), + }), + ) + requireNoErrResp(t, resp2, err) + + // again to test idempotency + resp2, err = s.unauthdClients.connections.DeleteConnection( + s.ctx, + connect.NewRequest(&mgmtv1alpha1.DeleteConnectionRequest{ + Id: conn.GetId(), + }), + ) + requireNoErrResp(t, resp2, err) +} + +func (s *IntegrationTestSuite) Test_ConnectionService_CheckSqlQuery() { + t := s.T() + accountId := s.createPersonalAccount(s.ctx, s.unauthdClients.users) + pgconnstr, err := s.pgcontainer.ConnectionString(s.ctx, "sslmode=disable") + require.NoError(t, err) + + conn := s.createPostgresConnection(s.unauthdClients.connections, accountId, "foo", pgconnstr) + + resp, err := s.unauthdClients.connections.CheckSqlQuery( + s.ctx, + connect.NewRequest(&mgmtv1alpha1.CheckSqlQueryRequest{ + Id: conn.GetId(), + Query: "SELECT 1", + }), + ) + requireNoErrResp(t, resp, err) + require.True(t, resp.Msg.GetIsValid()) + require.Empty(t, resp.Msg.GetErorrMessage()) +} diff --git a/backend/sqlc.yaml b/backend/sqlc.yaml index a58496ddea..87185b8931 100644 --- a/backend/sqlc.yaml +++ b/backend/sqlc.yaml @@ -30,7 +30,7 @@ sql: package: pg_models type: VirtualForeignConstraint pointer: true - slice: true + slice: true - column: neosync_api.jobs.connection_options go_type: import: github.com/nucleuscloud/neosync/backend/sql/postgresql/models @@ -93,7 +93,6 @@ sql: go: package: "pg_queries" out: "gen/go/db/dbschemas/postgresql" - sql_package: "pgx/v5" emit_interface: true emit_methods_with_db_argument: true emit_result_struct_pointers: true diff --git a/cli/internal/cmds/neosync/sync/sync.go b/cli/internal/cmds/neosync/sync/sync.go index 3352b3308c..67ae42009b 100644 --- a/cli/internal/cmds/neosync/sync/sync.go +++ b/cli/internal/cmds/neosync/sync/sync.go @@ -42,7 +42,6 @@ import ( _ "github.com/warpstreamlabs/bento/public/components/io" _ "github.com/warpstreamlabs/bento/public/components/pure" _ "github.com/warpstreamlabs/bento/public/components/pure/extended" - _ "github.com/warpstreamlabs/bento/public/components/sql" http_client "github.com/nucleuscloud/neosync/worker/pkg/http/client" diff --git a/cli/internal/cmds/neosync/sync/ui.go b/cli/internal/cmds/neosync/sync/ui.go index b1b828a32d..79ae96fd0b 100644 --- a/cli/internal/cmds/neosync/sync/ui.go +++ b/cli/internal/cmds/neosync/sync/ui.go @@ -18,7 +18,6 @@ import ( _ "github.com/warpstreamlabs/bento/public/components/io" _ "github.com/warpstreamlabs/bento/public/components/pure" _ "github.com/warpstreamlabs/bento/public/components/pure/extended" - _ "github.com/warpstreamlabs/bento/public/components/sql" "github.com/charmbracelet/bubbles/spinner" tea "github.com/charmbracelet/bubbletea" diff --git a/go.mod b/go.mod index 4be56f2885..5fa5f012ff 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/dop251/goja_nodejs v0.0.0-20231122114759-e84d9a924c5c github.com/doug-martin/goqu/v9 v9.19.0 github.com/fatih/color v1.17.0 + github.com/gliderlabs/ssh v0.3.7 github.com/go-logr/logr v1.4.2 github.com/go-sql-driver/mysql v1.8.1 github.com/gofrs/uuid v4.4.0+incompatible @@ -117,9 +118,7 @@ require ( github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect github.com/Azure/go-autorest/logger v0.2.1 // indirect github.com/Azure/go-autorest/tracing v0.6.0 // indirect - github.com/ClickHouse/ch-go v0.61.5 // indirect github.com/ClickHouse/clickhouse-go v1.5.4 // indirect - github.com/ClickHouse/clickhouse-go/v2 v2.21.1 // indirect github.com/GoogleCloudPlatform/grpc-gcp-go/grpcgcp v1.5.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.21.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.45.0 // indirect @@ -134,6 +133,7 @@ require ( github.com/PaesslerAG/jsonpath v0.1.1 // indirect github.com/PuerkitoBio/rehttp v1.4.0 // indirect github.com/andybalholm/brotli v1.1.0 // indirect + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230512164433-5d1fd1a340c9 // indirect github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect github.com/apache/arrow/go/v14 v14.0.2 // indirect @@ -169,12 +169,6 @@ require ( github.com/aymerick/douceur v0.2.0 // indirect github.com/benhoyt/goawk v1.25.0 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/btnguyen2k/consu/checksum v1.1.0 // indirect - github.com/btnguyen2k/consu/g18 v0.1.0 // indirect - github.com/btnguyen2k/consu/gjrc v0.2.2 // indirect - github.com/btnguyen2k/consu/olaf v0.1.3 // indirect - github.com/btnguyen2k/consu/reddo v0.1.8 // indirect - github.com/btnguyen2k/consu/semita v0.1.5 // indirect github.com/bufbuild/protocompile v0.8.0 // indirect github.com/bufbuild/protovalidate-go v0.3.0 // indirect github.com/bwmarrin/snowflake v0.3.0 // indirect @@ -197,7 +191,6 @@ require ( github.com/cznic/mathutil v0.0.0-20180504122225-ca4c9f2c1369 // indirect github.com/danieljoos/wincred v1.2.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/denisenkom/go-mssqldb v0.12.3 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/dlclark/regexp2 v1.10.0 // indirect @@ -221,8 +214,6 @@ require ( github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/generikvault/gvalstrings v0.0.0-20180926130504-471f38f0112a // indirect github.com/go-faker/faker/v4 v4.3.0 // indirect - github.com/go-faster/city v1.0.1 // indirect - github.com/go-faster/errors v0.7.1 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.3.0 // indirect @@ -315,7 +306,6 @@ require ( github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/microcosm-cc/bluemonday v1.0.25 // indirect - github.com/microsoft/gocosmos v1.1.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect github.com/moby/patternmatcher v0.6.0 // indirect @@ -347,7 +337,6 @@ require ( github.com/opencontainers/image-spec v1.1.0 // indirect github.com/opensearch-project/opensearch-go/v3 v3.0.0 // indirect github.com/parquet-go/parquet-go v0.23.0 // indirect - github.com/paulmach/orb v0.11.1 // indirect github.com/pborman/uuid v1.2.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect @@ -371,13 +360,11 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sahilm/fuzzy v0.1.1 // indirect - github.com/segmentio/asm v1.2.0 // indirect github.com/segmentio/encoding v0.4.0 // indirect github.com/segmentio/ksuid v1.0.4 // indirect github.com/shirou/gopsutil/v3 v3.24.2 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/shopspring/decimal v1.3.1 // indirect - github.com/sijms/go-ora/v2 v2.8.19 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/snowflakedb/gosnowflake v1.7.2 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -391,7 +378,6 @@ require ( github.com/tilinna/z85 v1.0.0 // indirect github.com/tklauser/go-sysconf v0.3.13 // indirect github.com/tklauser/numcpus v0.7.0 // indirect - github.com/trinodb/trino-go-client v0.313.0 // indirect github.com/twmb/franz-go v1.16.1 // indirect github.com/twmb/franz-go/pkg/kmsg v1.7.0 // indirect github.com/urfave/cli/v2 v2.27.1 // indirect @@ -432,10 +418,6 @@ require ( gopkg.in/go-jose/go-jose.v2 v2.6.3 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/jcmturner/aescts.v1 v1.0.1 // indirect - gopkg.in/jcmturner/dnsutils.v1 v1.0.1 // indirect - gopkg.in/jcmturner/gokrb5.v6 v6.1.1 // indirect - gopkg.in/jcmturner/rpc.v1 v1.1.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 // indirect diff --git a/go.sum b/go.sum index 969f91ea8c..ecd982c7ce 100644 --- a/go.sum +++ b/go.sum @@ -660,17 +660,14 @@ github.com/Azure/azure-sdk-for-go v68.0.0+incompatible h1:fcYLmCpyNYRnvJbPerq7U0 github.com/Azure/azure-sdk-for-go v68.0.0+incompatible/go.mod h1:9XXNKU+eRnpl9moKnB4QOLf1HestfXbmab5FXxiDBjc= github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.2 h1:Cx/DLZDK2Gaew4y+P1+CsegTonTMrwQIUz8RZvjQt3I= github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.2/go.mod h1:JVVfPiAgcVJ6HrD3A4CRryuEb5rFJAZ4nFYnUFsj6vs= -github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.6 h1:oBqQLSI1pZwGOdXJAoJJSzmff9tlfD4KroVfjQQmd0g= github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos v0.3.6/go.mod h1:Beh5cHIXJ0oWEDWk9lNFtuklCojLLQ5hl+LqSNTTs0I= github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.1.0 h1:ONYihl/vbwtVAmEmqoVDCGyhad2CIMN2kg3BO8Y5cFk= github.com/Azure/azure-sdk-for-go/sdk/data/aztables v1.1.0/go.mod h1:PMB5kQ1apg/irrvpPryVdchapVIYP+VV9iHJQ2CHwG8= -github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.5.0 h1:AifHbc4mg0x9zW52WOpKbsHaDKuRhlI7TVl47thgQ70= @@ -763,6 +760,8 @@ github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGW github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230512164433-5d1fd1a340c9 h1:goHVqTbFX3AIo0tzGr14pgfAW2ZfPChKO21Z9MGf/gk= github.com/antlr/antlr4/runtime/Go/antlr/v4 v4.0.0-20230512164433-5d1fd1a340c9/go.mod h1:pSwJ0fSY5KhvocuWSx4fz3BA8OrA1bQn+K1Eli3BRwM= @@ -908,13 +907,10 @@ github.com/btnguyen2k/consu/gjrc v0.2.2 h1:CAY8xPgvtWc7EMTE9gxam/BxMgTRRpc4Hs9QE github.com/btnguyen2k/consu/gjrc v0.2.2/go.mod h1:Sc0NehbI0i8V6FAY9qX1we9XXbWNnrMOb9jNpYqGBWk= github.com/btnguyen2k/consu/olaf v0.1.3 h1:0dWWmN5nOB/9pJdo7o1S3wR2+l3kG7pXHv3Vwki8uNM= github.com/btnguyen2k/consu/olaf v0.1.3/go.mod h1:6ybEnJcdcK/PNiSfkKnMoxYuKyH2vJPBvHRuuZpPvD8= -github.com/btnguyen2k/consu/reddo v0.1.7/go.mod h1:pdY5oIVX3noZIaZu3nvoKZ59+seXL/taXNGWh9xJDbg= github.com/btnguyen2k/consu/reddo v0.1.8 h1:pEAkB6eadp/q+ONy97/JkAAyj058uIgkSu8b862Fwug= github.com/btnguyen2k/consu/reddo v0.1.8/go.mod h1:pdY5oIVX3noZIaZu3nvoKZ59+seXL/taXNGWh9xJDbg= github.com/btnguyen2k/consu/semita v0.1.5 h1:fu71xNJTbCV8T+6QPJdJu3bxtmLWvTjCepkvujF74+I= github.com/btnguyen2k/consu/semita v0.1.5/go.mod h1:fksCe3L4kxiJVnKKhUXKI8mcFdB9974mtedwUVVFu1M= -github.com/btnguyen2k/consu/semver v0.2.1 h1:le0FzrM7u0IOR4MnOyBySHpZ/p3vV4JjofAhPB7edWE= -github.com/btnguyen2k/consu/semver v0.2.1/go.mod h1:jxK/nwIWTXcWlcWcfkhPfLWq9b5dVzAtJLycySBFHTc= github.com/bufbuild/protocompile v0.8.0 h1:9Kp1q6OkS9L4nM3FYbr8vlJnEwtbpDPQlQOVXfR+78s= github.com/bufbuild/protocompile v0.8.0/go.mod h1:+Etjg4guZoAqzVk2czwEQP12yaxLJ8DxuqCJ9qHdH94= github.com/bufbuild/protovalidate-go v0.3.0 h1:t9zKgM//9VtPnP0TvyFqWubLQtSbwLwEUVOxgtX9/os= @@ -1035,7 +1031,6 @@ github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91/go.mod h1:2pZnwu github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= github.com/docker/cli v25.0.3+incompatible h1:KLeNs7zws74oFuVhgZQ5ONGZiXUUdgsdy6/EsX/6284= github.com/docker/cli v25.0.3+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/docker v27.2.0+incompatible h1:Rk9nIVdfH3+Vz4cyI/uhbINhEZ/oLmc+CBXmH6fbNk4= @@ -1124,6 +1119,8 @@ github.com/generikvault/gvalstrings v0.0.0-20180926130504-471f38f0112a/go.mod h1 github.com/getsentry/sentry-go v0.27.0 h1:Pv98CIbtB3LkMWmXi4Joa5OOcwbmnX88sF5qbK3r3Ps= github.com/getsentry/sentry-go v0.27.0/go.mod h1:lc76E2QywIyW8WuBnwl8Lc4bkmQH4+w1gwTf25trprY= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gliderlabs/ssh v0.3.7 h1:iV3Bqi942d9huXnzEF2Mt+CY9gLu8DNM4Obd+8bODRE= +github.com/gliderlabs/ssh v0.3.7/go.mod h1:zpHEXBstFnQYtGnB8k8kQLol82umzn/2/snG7alWVD8= github.com/go-faker/faker/v4 v4.3.0 h1:UXOW7kn/Mwd0u6MR30JjUKVzguT20EB/hBOddAAO+DY= github.com/go-faker/faker/v4 v4.3.0/go.mod h1:F/bBy8GH9NxOxMInug5Gx4WYeG6fHJZ8Ol/dhcpRub4= github.com/go-faster/city v1.0.1 h1:4WAxSZ3V2Ws4QRDrscLEDcibJY8uf41H6AhXDrNDcGw= @@ -1529,7 +1526,6 @@ github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCy github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= github.com/klauspost/compress v1.9.7/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= github.com/klauspost/compress v1.13.1/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= -github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= @@ -1644,8 +1640,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= -github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= -github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/montanaflynn/stats v0.7.1 h1:etflOAAHORrCC44V+aR6Ftzort912ZU+YLiSTuV8eaE= github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= @@ -1725,7 +1719,6 @@ github.com/parquet-go/parquet-go v0.23.0 h1:dyEU5oiHCtbASyItMCD2tXtT2nPmoPbKpqf0 github.com/parquet-go/parquet-go v0.23.0/go.mod h1:MnwbUcFHU6uBYMymKAlPPAw9yh3kE1wWl6Gl1uLdkNk= github.com/paulmach/orb v0.11.1 h1:3koVegMC4X/WeiXYz9iswopaTwMem53NzTJuTF20JzU= github.com/paulmach/orb v0.11.1/go.mod h1:5mULz1xQfs3bmQm63QEJA6lNGujuRafwA5S/EnuLaLU= -github.com/paulmach/protoscan v0.2.1/go.mod h1:SpcSwydNLrxUGSDvXvO0P7g7AuhJ7lcKfDlhJCDw2gY= github.com/pborman/getopt v0.0.0-20180729010549-6fdd0a2c7117/go.mod h1:85jBQOZwpVEaDAr341tbn15RS4fCAsIst0qp7i8ex1o= github.com/pborman/uuid v1.2.1 h1:+ZZIw58t/ozdjRaXh/3awHfmWRbzYxJoAdNJxe/3pvw= github.com/pborman/uuid v1.2.1/go.mod h1:X/NO0urCmaxf9VXbdlT7C2Yzkj2IKimNn4k+gtPdI/k= @@ -1745,7 +1738,6 @@ github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuR github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= @@ -1917,7 +1909,6 @@ github.com/testcontainers/testcontainers-go/modules/redis v0.33.0 h1:S/QvMOwpr00 github.com/testcontainers/testcontainers-go/modules/redis v0.33.0/go.mod h1:gudb3+6uZ9SsAysOVoLs7nazbjGlkHegBW8nqPXvDMI= github.com/tetratelabs/wazero v1.8.0 h1:iEKu0d4c2Pd+QSRieYbnQC9yiFlMS9D+Jr0LsRmcF4g= github.com/tetratelabs/wazero v1.8.0/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= -github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tilinna/z85 v1.0.0 h1:uqFnJBlD01dosSeo5sK1G1YGbPuwqVHqR+12OJDRjUw= github.com/tilinna/z85 v1.0.0/go.mod h1:EfpFU/DUY4ddEy6CRvk2l+UQNEzHbh+bqBQS+04Nkxs= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= @@ -1954,10 +1945,8 @@ github.com/xanzy/go-gitlab v0.15.0 h1:rWtwKTgEnXyNUGrOArN7yyc3THRkpYcKXIXia9abyw github.com/xanzy/go-gitlab v0.15.0/go.mod h1:8zdQa/ri1dfn8eS3Ir1SyfvOKlw7WBJ8DVThkpGiXrs= github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= -github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -1978,7 +1967,6 @@ github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRT github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 h1:zzrxE1FKn5ryBNl9eKOeqQ58Y/Qpo3Q9QNxKHX5uzzQ= github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2/go.mod h1:hzfGeIUDq/j97IG+FhNqkowIyEcD88LrW6fyU3K3WqY= -github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -2006,7 +1994,6 @@ go.etcd.io/etcd/client/pkg/v3 v3.5.14 h1:SaNH6Y+rVEdxfpA2Jr5wkEvN6Zykme5+YnbCkxv go.etcd.io/etcd/client/pkg/v3 v3.5.14/go.mod h1:8uMgAokyG1czCtIdsq+AGyYQMvpIKnSvPjFMunkgeZI= go.etcd.io/etcd/client/v3 v3.5.14 h1:CWfRs4FDaDoSz81giL7zPpZH2Z35tbOrAJkkjMqOupg= go.etcd.io/etcd/client/v3 v3.5.14/go.mod h1:k3XfdV/VIHy/97rqWjoUzrj9tk7GgJGH9J8L4dNXmAk= -go.mongodb.org/mongo-driver v1.11.4/go.mod h1:PTSz5yu21bkT/wXpkS7WR5f0ddqw5quethTUn9WM+2g= go.mongodb.org/mongo-driver v1.17.1 h1:Wic5cJIwJgSpBhe3lx3+/RybR5PiYRMpVFgO7cOHyIM= go.mongodb.org/mongo-driver v1.17.1/go.mod h1:wwWm/+BuOddhcq3n68LKRmgk2wXzmF6s0SFOa0GINL4= go.nanomsg.org/mangos/v3 v3.4.2 h1:gHlopxjWvJcVCcUilQIsRQk9jdj6/HB7wrTiUN8Ki7Q= @@ -2094,7 +2081,6 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201002170205-7f63de1d35b0/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20201016220609-9e8e0b390897/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= @@ -2104,7 +2090,6 @@ golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20210817164053-32db794688a5/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.20.0/go.mod h1:Xwo95rrVNIoSMx9wa1JroENMToLWn3RNVrTBpLHgZPQ= @@ -2215,11 +2200,9 @@ golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96b golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210510120150-4163338589ed/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210520170846-37e1c6afe023/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220325170049-de3da57026de/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= @@ -2821,7 +2804,6 @@ gopkg.in/jcmturner/aescts.v1 v1.0.1 h1:cVVZBK2b1zY26haWB4vbBiZrfFQnfbTVrE3xZq6hr gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo= gopkg.in/jcmturner/dnsutils.v1 v1.0.1 h1:cIuC1OLRGZrld+16ZJvvZxVJeKPsvd5eUIvxfoN5hSM= gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q= -gopkg.in/jcmturner/goidentity.v3 v3.0.0 h1:1duIyWiTaYvVx3YX2CYtpJbUFd7/UuPYCfgXtQ3VTbI= gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4= gopkg.in/jcmturner/gokrb5.v6 v6.1.1 h1:n0KFjpbuM5pFMN38/Ay+Br3l91netGSVqHPHEXeWUqk= gopkg.in/jcmturner/gokrb5.v6 v6.1.1/go.mod h1:NFjHNLrHQiruory+EmqDXCGv6CrjkeYeA+bR9mIfNFk= diff --git a/internal/mocks/github.com/jackc/pgx/v5/mock_Tx.go b/internal/mocks/github.com/jackc/pgx/v5/mock_Tx.go index 988afe205b..db01ac096b 100644 --- a/internal/mocks/github.com/jackc/pgx/v5/mock_Tx.go +++ b/internal/mocks/github.com/jackc/pgx/v5/mock_Tx.go @@ -235,7 +235,7 @@ func (_c *MockTx_CopyFrom_Call) RunAndReturn(run func(context.Context, pgx.Ident } // Exec provides a mock function with given fields: ctx, sql, arguments -func (_m *MockTx) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { +func (_m *MockTx) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) { var _ca []interface{} _ca = append(_ca, ctx, sql) _ca = append(_ca, arguments...) @@ -247,16 +247,16 @@ func (_m *MockTx) Exec(ctx context.Context, sql string, arguments ...interface{} var r0 pgconn.CommandTag var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) (pgconn.CommandTag, error)); ok { return rf(ctx, sql, arguments...) } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgconn.CommandTag); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) pgconn.CommandTag); ok { r0 = rf(ctx, sql, arguments...) } else { r0 = ret.Get(0).(pgconn.CommandTag) } - if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, ...any) error); ok { r1 = rf(ctx, sql, arguments...) } else { r1 = ret.Error(1) @@ -273,18 +273,18 @@ type MockTx_Exec_Call struct { // Exec is a helper method to define mock.On call // - ctx context.Context // - sql string -// - arguments ...interface{} +// - arguments ...any func (_e *MockTx_Expecter) Exec(ctx interface{}, sql interface{}, arguments ...interface{}) *MockTx_Exec_Call { return &MockTx_Exec_Call{Call: _e.mock.On("Exec", append([]interface{}{ctx, sql}, arguments...)...)} } -func (_c *MockTx_Exec_Call) Run(run func(ctx context.Context, sql string, arguments ...interface{})) *MockTx_Exec_Call { +func (_c *MockTx_Exec_Call) Run(run func(ctx context.Context, sql string, arguments ...any)) *MockTx_Exec_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]interface{}, len(args)-2) + variadicArgs := make([]any, len(args)-2) for i, a := range args[2:] { if a != nil { - variadicArgs[i] = a.(interface{}) + variadicArgs[i] = a.(any) } } run(args[0].(context.Context), args[1].(string), variadicArgs...) @@ -297,7 +297,7 @@ func (_c *MockTx_Exec_Call) Return(commandTag pgconn.CommandTag, err error) *Moc return _c } -func (_c *MockTx_Exec_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (pgconn.CommandTag, error)) *MockTx_Exec_Call { +func (_c *MockTx_Exec_Call) RunAndReturn(run func(context.Context, string, ...any) (pgconn.CommandTag, error)) *MockTx_Exec_Call { _c.Call.Return(run) return _c } @@ -408,7 +408,7 @@ func (_c *MockTx_Prepare_Call) RunAndReturn(run func(context.Context, string, st } // Query provides a mock function with given fields: ctx, sql, args -func (_m *MockTx) Query(ctx context.Context, sql string, args ...interface{}) (pgx.Rows, error) { +func (_m *MockTx) Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error) { var _ca []interface{} _ca = append(_ca, ctx, sql) _ca = append(_ca, args...) @@ -420,10 +420,10 @@ func (_m *MockTx) Query(ctx context.Context, sql string, args ...interface{}) (p var r0 pgx.Rows var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) (pgx.Rows, error)); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) (pgx.Rows, error)); ok { return rf(ctx, sql, args...) } - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Rows); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) pgx.Rows); ok { r0 = rf(ctx, sql, args...) } else { if ret.Get(0) != nil { @@ -431,7 +431,7 @@ func (_m *MockTx) Query(ctx context.Context, sql string, args ...interface{}) (p } } - if rf, ok := ret.Get(1).(func(context.Context, string, ...interface{}) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, string, ...any) error); ok { r1 = rf(ctx, sql, args...) } else { r1 = ret.Error(1) @@ -448,18 +448,18 @@ type MockTx_Query_Call struct { // Query is a helper method to define mock.On call // - ctx context.Context // - sql string -// - args ...interface{} +// - args ...any func (_e *MockTx_Expecter) Query(ctx interface{}, sql interface{}, args ...interface{}) *MockTx_Query_Call { return &MockTx_Query_Call{Call: _e.mock.On("Query", append([]interface{}{ctx, sql}, args...)...)} } -func (_c *MockTx_Query_Call) Run(run func(ctx context.Context, sql string, args ...interface{})) *MockTx_Query_Call { +func (_c *MockTx_Query_Call) Run(run func(ctx context.Context, sql string, args ...any)) *MockTx_Query_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]interface{}, len(args)-2) + variadicArgs := make([]any, len(args)-2) for i, a := range args[2:] { if a != nil { - variadicArgs[i] = a.(interface{}) + variadicArgs[i] = a.(any) } } run(args[0].(context.Context), args[1].(string), variadicArgs...) @@ -472,13 +472,13 @@ func (_c *MockTx_Query_Call) Return(_a0 pgx.Rows, _a1 error) *MockTx_Query_Call return _c } -func (_c *MockTx_Query_Call) RunAndReturn(run func(context.Context, string, ...interface{}) (pgx.Rows, error)) *MockTx_Query_Call { +func (_c *MockTx_Query_Call) RunAndReturn(run func(context.Context, string, ...any) (pgx.Rows, error)) *MockTx_Query_Call { _c.Call.Return(run) return _c } // QueryRow provides a mock function with given fields: ctx, sql, args -func (_m *MockTx) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row { +func (_m *MockTx) QueryRow(ctx context.Context, sql string, args ...any) pgx.Row { var _ca []interface{} _ca = append(_ca, ctx, sql) _ca = append(_ca, args...) @@ -489,7 +489,7 @@ func (_m *MockTx) QueryRow(ctx context.Context, sql string, args ...interface{}) } var r0 pgx.Row - if rf, ok := ret.Get(0).(func(context.Context, string, ...interface{}) pgx.Row); ok { + if rf, ok := ret.Get(0).(func(context.Context, string, ...any) pgx.Row); ok { r0 = rf(ctx, sql, args...) } else { if ret.Get(0) != nil { @@ -508,18 +508,18 @@ type MockTx_QueryRow_Call struct { // QueryRow is a helper method to define mock.On call // - ctx context.Context // - sql string -// - args ...interface{} +// - args ...any func (_e *MockTx_Expecter) QueryRow(ctx interface{}, sql interface{}, args ...interface{}) *MockTx_QueryRow_Call { return &MockTx_QueryRow_Call{Call: _e.mock.On("QueryRow", append([]interface{}{ctx, sql}, args...)...)} } -func (_c *MockTx_QueryRow_Call) Run(run func(ctx context.Context, sql string, args ...interface{})) *MockTx_QueryRow_Call { +func (_c *MockTx_QueryRow_Call) Run(run func(ctx context.Context, sql string, args ...any)) *MockTx_QueryRow_Call { _c.Call.Run(func(args mock.Arguments) { - variadicArgs := make([]interface{}, len(args)-2) + variadicArgs := make([]any, len(args)-2) for i, a := range args[2:] { if a != nil { - variadicArgs[i] = a.(interface{}) + variadicArgs[i] = a.(any) } } run(args[0].(context.Context), args[1].(string), variadicArgs...) @@ -532,7 +532,7 @@ func (_c *MockTx_QueryRow_Call) Return(_a0 pgx.Row) *MockTx_QueryRow_Call { return _c } -func (_c *MockTx_QueryRow_Call) RunAndReturn(run func(context.Context, string, ...interface{}) pgx.Row) *MockTx_QueryRow_Call { +func (_c *MockTx_QueryRow_Call) RunAndReturn(run func(context.Context, string, ...any) pgx.Row) *MockTx_QueryRow_Call { _c.Call.Return(run) return _c } diff --git a/internal/sshtunnel/connectors/mssqltunconnector/connector.go b/internal/sshtunnel/connectors/mssqltunconnector/connector.go new file mode 100644 index 0000000000..a841207de6 --- /dev/null +++ b/internal/sshtunnel/connectors/mssqltunconnector/connector.go @@ -0,0 +1,25 @@ +package mssqltunconnector + +import ( + "database/sql/driver" + + mssql "github.com/microsoft/go-mssqldb" + "github.com/nucleuscloud/neosync/internal/sshtunnel" +) + +type Connector struct { + driver.Connector +} + +var _ driver.Connector = (*Connector)(nil) + +func New(dialer sshtunnel.Dialer, dsn string) (*Connector, func(), error) { + connector, err := mssql.NewConnector(dsn) + if err != nil { + return nil, nil, err + } + + connector.Dialer = mssql.Dialer(dialer) + + return &Connector{Connector: connector}, func() {}, nil +} diff --git a/internal/sshtunnel/connectors/mssqltunconnector/connector_test.go b/internal/sshtunnel/connectors/mssqltunconnector/connector_test.go new file mode 100644 index 0000000000..f1c52bbfb9 --- /dev/null +++ b/internal/sshtunnel/connectors/mssqltunconnector/connector_test.go @@ -0,0 +1,24 @@ +package mssqltunconnector + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_New(t *testing.T) { + t.Run("valid", func(t *testing.T) { + connector, cleanup, err := New(&net.Dialer{}, "sqlserver://sa:myStr0ngP%40assword@localhost?database=master") + require.NoError(t, err) + require.NotNil(t, cleanup) + require.NotNil(t, connector) + }) + + t.Run("invalid dsn", func(t *testing.T) { + connector, cleanup, err := New(&net.Dialer{}, "sqlserver://sa:myStr0ngP%40assword@localhost:invalidport") + require.Error(t, err) + require.Nil(t, cleanup) + require.Nil(t, connector) + }) +} diff --git a/internal/sshtunnel/connectors/mysqltunconnector/connector.go b/internal/sshtunnel/connectors/mysqltunconnector/connector.go new file mode 100644 index 0000000000..45177a1657 --- /dev/null +++ b/internal/sshtunnel/connectors/mysqltunconnector/connector.go @@ -0,0 +1,54 @@ +package mysqltunconnector + +import ( + "context" + "database/sql/driver" + "fmt" + "net" + "strings" + + "github.com/go-sql-driver/mysql" + "github.com/google/uuid" + "github.com/nucleuscloud/neosync/internal/sshtunnel" +) + +type Connector struct { + driver.Connector +} + +var _ driver.Connector = (*Connector)(nil) + +func New(dialer sshtunnel.Dialer, dsn string) (*Connector, func(), error) { + cfg, err := mysql.ParseDSN(dsn) + if err != nil { + return nil, nil, fmt.Errorf("unable to parse mysql dsn: %w", err) + } + + ogNetwork := cfg.Net + newNetwork := buildUniqueNetwork(ogNetwork) + + cfg.Net = newNetwork + mysql.RegisterDialContext(cfg.Net, func(ctx context.Context, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, ogNetwork, addr) + }) + + conn, err := mysql.NewConnector(cfg) + if err != nil { + return nil, nil, err + } + + cleanup := func() { + mysql.DeregisterDialContext(newNetwork) + } + + return &Connector{Connector: conn}, cleanup, nil +} + +func buildUniqueNetwork(network string) string { + return fmt.Sprintf("%s_%s", network, getUniqueIdentifier()) +} + +func getUniqueIdentifier() string { + id := uuid.NewString() + return strings.ReplaceAll(id, "-", "") +} diff --git a/internal/sshtunnel/connectors/mysqltunconnector/connector_test.go b/internal/sshtunnel/connectors/mysqltunconnector/connector_test.go new file mode 100644 index 0000000000..41402685ad --- /dev/null +++ b/internal/sshtunnel/connectors/mysqltunconnector/connector_test.go @@ -0,0 +1,25 @@ +package mysqltunconnector + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_New(t *testing.T) { + t.Run("valid", func(t *testing.T) { + connector, cleanup, err := New(&net.Dialer{}, "foo:bar@tcp(localhost:3306)/mydb") + require.NoError(t, err) + require.NotNil(t, cleanup) + require.NotNil(t, connector) + cleanup() + }) + + t.Run("invalid conn", func(t *testing.T) { + connector, cleanup, err := New(&net.Dialer{}, "foo:bar@tcp(localhost:3306)") + require.Error(t, err) + require.Nil(t, cleanup) + require.Nil(t, connector) + }) +} diff --git a/internal/sshtunnel/connectors/postgrestunconnector/connector.go b/internal/sshtunnel/connectors/postgrestunconnector/connector.go new file mode 100644 index 0000000000..b9066d2b23 --- /dev/null +++ b/internal/sshtunnel/connectors/postgrestunconnector/connector.go @@ -0,0 +1,46 @@ +package postgrestunconnector + +import ( + "context" + "database/sql/driver" + "net" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/stdlib" + "github.com/nucleuscloud/neosync/internal/sshtunnel" +) + +type Connector struct { + connStr string + driver driver.Driver +} + +var _ driver.Connector = (*Connector)(nil) + +func New( + dialer sshtunnel.Dialer, + dsn string, +) (*Connector, func(), error) { + cfg, err := pgx.ParseConfig(dsn) + if err != nil { + return nil, nil, err + } + cfg.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.DialContext(ctx, network, addr) + } + + connStr := stdlib.RegisterConnConfig(cfg) + cleanup := func() { + stdlib.UnregisterConnConfig(connStr) + } + + return &Connector{connStr: connStr, driver: stdlib.GetDefaultDriver()}, cleanup, nil +} + +func (c *Connector) Connect(_ context.Context) (driver.Conn, error) { + return c.driver.Open(c.connStr) +} + +func (c *Connector) Driver() driver.Driver { + return c.driver +} diff --git a/internal/sshtunnel/connectors/postgrestunconnector/connector_test.go b/internal/sshtunnel/connectors/postgrestunconnector/connector_test.go new file mode 100644 index 0000000000..efba4eba66 --- /dev/null +++ b/internal/sshtunnel/connectors/postgrestunconnector/connector_test.go @@ -0,0 +1,32 @@ +package postgrestunconnector + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func Test_New(t *testing.T) { + t.Run("valid", func(t *testing.T) { + connector, cleanup, err := New(&net.Dialer{}, "postgres://postgres:postgres@localhost:5432") + require.NoError(t, err) + require.NotNil(t, cleanup) + require.NotNil(t, connector) + cleanup() + }) + + t.Run("invalid conn", func(t *testing.T) { + connector, cleanup, err := New(&net.Dialer{}, "foo:bar@tcp(localhost:3306)") + require.Error(t, err) + require.Nil(t, cleanup) + require.Nil(t, connector) + }) +} + +func Test_Connector_Driver(t *testing.T) { + connector, _, err := New(&net.Dialer{}, "postgres://postgres:postgres@localhost:5432") + require.NoError(t, err) + driver := connector.Driver() + require.NotEmpty(t, driver) +} diff --git a/internal/sshtunnel/dialer.go b/internal/sshtunnel/dialer.go new file mode 100644 index 0000000000..39e99c1b07 --- /dev/null +++ b/internal/sshtunnel/dialer.go @@ -0,0 +1,103 @@ +package sshtunnel + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "golang.org/x/crypto/ssh" +) + +type Dialer interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) + Dial(network, addr string) (net.Conn, error) +} + +var _ Dialer = (*SSHDialer)(nil) + +type SSHDialer struct { + addr string + cfg *ssh.ClientConfig + + client *ssh.Client + clientmu *sync.RWMutex +} + +func NewLazySSHDialer(addr string, cfg *ssh.ClientConfig) *SSHDialer { + return &SSHDialer{addr: addr, cfg: cfg, clientmu: &sync.RWMutex{}} +} + +func NewSSHDialer(client *ssh.Client) *SSHDialer { + return &SSHDialer{client: client, clientmu: &sync.RWMutex{}} +} + +func (s *SSHDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + client, err := s.getClient() + if err != nil { + return nil, err + } + conn, err := client.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + return &wrappedSshConn{Conn: conn}, nil +} + +func (s *SSHDialer) Dial(network, addr string) (net.Conn, error) { + return s.DialContext(context.Background(), network, addr) +} + +func (s *SSHDialer) Close() error { + s.clientmu.Lock() + defer s.clientmu.Unlock() + if s.client == nil { + return nil + } + client := s.client + s.client = nil + return client.Close() +} + +func (s *SSHDialer) getClient() (*ssh.Client, error) { + s.clientmu.RLock() + client := s.client + s.clientmu.RUnlock() + if client != nil { + return client, nil + } + s.clientmu.Lock() + defer s.clientmu.Unlock() + if s.client != nil { + return s.client, nil + } + // todo: implement retries + client, err := ssh.Dial("tcp", s.addr, s.cfg) + if err != nil { + return nil, fmt.Errorf("unable to dial ssh server: %w", err) + } + s.client = client + return client, nil +} + +type wrappedSshConn struct { + net.Conn +} + +func (w *wrappedSshConn) SetDeadline(deadline time.Time) error { + if err := w.SetReadDeadline(deadline); err != nil { + return err + } + return w.SetWriteDeadline(deadline) +} + +// SSH net.Conn does not implement this, so we're overriding it to not return an error +func (w *wrappedSshConn) SetReadDeadline(deadline time.Time) error { + return nil +} + +// SSH net.Conn does not implement this, so we're overriding it to not return an error +func (w *wrappedSshConn) SetWriteDeadline(deadline time.Time) error { + return nil +} diff --git a/internal/sshtunnel/dialer_integration_test.go b/internal/sshtunnel/dialer_integration_test.go new file mode 100644 index 0000000000..5a947c5de9 --- /dev/null +++ b/internal/sshtunnel/dialer_integration_test.go @@ -0,0 +1,158 @@ +package sshtunnel_test + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "log/slog" + "os" + "testing" + "time" + + gssh "github.com/gliderlabs/ssh" + "github.com/nucleuscloud/neosync/internal/sshtunnel" + "github.com/nucleuscloud/neosync/internal/sshtunnel/connectors/mssqltunconnector" + "github.com/nucleuscloud/neosync/internal/sshtunnel/connectors/mysqltunconnector" + "github.com/nucleuscloud/neosync/internal/sshtunnel/connectors/postgrestunconnector" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" + + "github.com/testcontainers/testcontainers-go" + testmssql "github.com/testcontainers/testcontainers-go/modules/mssql" + testmysql "github.com/testcontainers/testcontainers-go/modules/mysql" + "github.com/testcontainers/testcontainers-go/modules/postgres" + testpg "github.com/testcontainers/testcontainers-go/modules/postgres" + + "github.com/testcontainers/testcontainers-go/wait" +) + +func Test_NewLazySSHDialer(t *testing.T) { + t.Parallel() + evkey := "INTEGRATION_TESTS_ENABLED" + shouldRun := os.Getenv(evkey) + if shouldRun != "1" { + slog.Warn(fmt.Sprintf("skipping integration tests, set %s=1 to enable", evkey)) + return + } + + ctx := context.Background() + + addr := ":2222" + server := newSshForwardServer(t, addr) + + go func() { + err := server.ListenAndServe() + if err != nil && err != gssh.ErrServerClosed { + panic(err) + } + }() + + time.Sleep(100 * time.Millisecond) + + cconfig := &ssh.ClientConfig{ + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + } + dialer := sshtunnel.NewLazySSHDialer(addr, cconfig) + defer dialer.Close() + + t.Run("postgres", func(t *testing.T) { + t.Parallel() + + container, err := testpg.Run( + ctx, + "postgres:15", + postgres.WithDatabase("postgres"), + testcontainers.WithWaitStrategy( + wait.ForLog("database system is ready to accept connections"). + WithOccurrence(2).WithStartupTimeout(20*time.Second), + ), + ) + require.NoError(t, err) + connstr, err := container.ConnectionString(ctx, "sslmode=disable") + require.NoError(t, err) + + connector, cleanup, err := postgrestunconnector.New(dialer, connstr) + require.NoError(t, err) + defer cleanup() + + requireDbConnects(t, connector) + }) + + t.Run("mysql", func(t *testing.T) { + t.Parallel() + + container, err := testmysql.Run(ctx, + "mysql:8.0.36", + testmysql.WithDatabase("mydb"), + testmysql.WithUsername("root"), + testmysql.WithPassword("password"), + testcontainers.WithWaitStrategy( + wait.ForLog("port: 3306 MySQL Community Server"). + WithOccurrence(1).WithStartupTimeout(20*time.Second), + ), + ) + require.NoError(t, err) + connstr, err := container.ConnectionString(ctx) + require.NoError(t, err) + + connector, cleanup, err := mysqltunconnector.New(dialer, connstr) + require.NoError(t, err) + defer cleanup() + + requireDbConnects(t, connector) + }) + + t.Run("mssql", func(t *testing.T) { + t.Parallel() + container, err := testmssql.Run(ctx, + "mcr.microsoft.com/mssql/server:2022-latest", + testmssql.WithAcceptEULA(), + testmssql.WithPassword("mssqlPASSword1"), + ) + require.NoError(t, err) + connstr, err := container.ConnectionString(ctx) + require.NoError(t, err) + + connector, cleanup, err := mssqltunconnector.New(dialer, connstr) + require.NoError(t, err) + defer cleanup() + + requireDbConnects(t, connector) + }) +} + +func requireDbConnects(t testing.TB, connector driver.Connector) { + db := sql.OpenDB(connector) + defer db.Close() + + err := db.Ping() + require.NoError(t, err) +} + +func newSshForwardServer(t testing.TB, addr string) *gssh.Server { + forwardHandler := &gssh.ForwardedTCPHandler{} + return &gssh.Server{ + Addr: addr, + Handler: gssh.Handler(func(s gssh.Session) { + select {} + }), + LocalPortForwardingCallback: gssh.LocalPortForwardingCallback(func(ctx gssh.Context, destinationHost string, destinationPort uint32) bool { + t.Logf("Accepted forward %s:%d\n", destinationHost, destinationPort) + return true + }), + ReversePortForwardingCallback: gssh.ReversePortForwardingCallback(func(ctx gssh.Context, destinationHost string, destinationPort uint32) bool { + t.Logf("attempt to bind %s:%d granted\n", destinationHost, destinationPort) + return true + }), + RequestHandlers: map[string]gssh.RequestHandler{ + "tcpip-forward": forwardHandler.HandleSSHRequest, + "cancel-tcpip-forward": forwardHandler.HandleSSHRequest, + }, + ChannelHandlers: map[string]gssh.ChannelHandler{ + "direct-tcpip": gssh.DirectTCPIPHandler, + "session": gssh.DefaultSessionHandler, + }, + } +} diff --git a/backend/pkg/sshtunnel/utils.go b/internal/sshtunnel/utils.go similarity index 100% rename from backend/pkg/sshtunnel/utils.go rename to internal/sshtunnel/utils.go diff --git a/backend/pkg/sshtunnel/utils_test.go b/internal/sshtunnel/utils_test.go similarity index 100% rename from backend/pkg/sshtunnel/utils_test.go rename to internal/sshtunnel/utils_test.go diff --git a/worker/internal/cmds/worker/serve/serve.go b/worker/internal/cmds/worker/serve/serve.go index 462257d8f9..ccf725766d 100644 --- a/worker/internal/cmds/worker/serve/serve.go +++ b/worker/internal/cmds/worker/serve/serve.go @@ -274,7 +274,7 @@ func serve(ctx context.Context) error { otelconfig.IsEnabled, ) disableReaper := false - syncActivity := sync_activity.New(connclient, jobclient, &sync.Map{}, temporalClient, syncActivityMeter, sync_activity.NewBenthosStreamManager(), disableReaper) + syncActivity := sync_activity.New(connclient, jobclient, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, temporalClient, syncActivityMeter, sync_activity.NewBenthosStreamManager(), disableReaper) retrieveActivityOpts := syncactivityopts_activity.New(jobclient) runSqlInitTableStatements := runsqlinittablestmts_activity.New(jobclient, connclient, sqlmanager) accountStatusActivity := accountstatus_activity.New(userclient) diff --git a/worker/internal/connection-tunnel-manager/manager.go b/worker/internal/connection-tunnel-manager/manager.go index 29f615c688..645273cc1d 100644 --- a/worker/internal/connection-tunnel-manager/manager.go +++ b/worker/internal/connection-tunnel-manager/manager.go @@ -1,29 +1,21 @@ package connectiontunnelmanager import ( - "errors" "fmt" "log/slog" "sync" "time" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - "github.com/nucleuscloud/neosync/backend/pkg/sshtunnel" - "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/shared" ) -type ConnectionProvider[T any, TConfig any] interface { - GetConnectionDetails(connectionConfig *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (ConnectionDetails, error) - GetConnectionClient(driver string, connectionString string, opts TConfig) (T, error) - GetConnectionClientConfig(connectionConfig *mgmtv1alpha1.ConnectionConfig) (TConfig, error) +type ConnectionProvider[T any] interface { + GetConnectionClient(connectionConfig *mgmtv1alpha1.ConnectionConfig) (T, error) CloseClientConnection(client T) error } -type ConnectionTunnelManager[T any, TConfig any] struct { - connectionProvider ConnectionProvider[T, TConfig] - - connDetailsMap map[string]ConnectionDetails - connDetailsMu sync.RWMutex +type ConnectionTunnelManager[T any] struct { + connectionProvider ConnectionProvider[T] sessionMap map[string]map[string]struct{} sessionMu sync.RWMutex @@ -36,11 +28,9 @@ type ConnectionTunnelManager[T any, TConfig any] struct { type ConnectionDetails interface { String() string - GetTunnel() *sshtunnel.Sshtunnel } type Interface[T any] interface { - GetConnectionString(session string, connection *mgmtv1alpha1.Connection, logger *slog.Logger) (string, error) GetConnection(session string, connection *mgmtv1alpha1.Connection, logger *slog.Logger) (T, error) ReleaseSession(session string) bool @@ -48,61 +38,17 @@ type Interface[T any] interface { Reaper() } -var _ Interface[any] = &ConnectionTunnelManager[any, any]{} // enforces ConnectionTunnelManager always conforms to the interface +var _ Interface[any] = &ConnectionTunnelManager[any]{} // enforces ConnectionTunnelManager always conforms to the interface -func NewConnectionTunnelManager[T any, TConfig any](connectionProvider ConnectionProvider[T, TConfig]) *ConnectionTunnelManager[T, TConfig] { - return &ConnectionTunnelManager[T, TConfig]{ +func NewConnectionTunnelManager[T any](connectionProvider ConnectionProvider[T]) *ConnectionTunnelManager[T] { + return &ConnectionTunnelManager[T]{ connectionProvider: connectionProvider, sessionMap: map[string]map[string]struct{}{}, - connDetailsMap: map[string]ConnectionDetails{}, connMap: map[string]T{}, } } -func (c *ConnectionTunnelManager[T, TConfig]) GetConnectionString( - session string, - connection *mgmtv1alpha1.Connection, - logger *slog.Logger, -) (string, error) { - c.connDetailsMu.RLock() - loadedDetails, ok := c.connDetailsMap[connection.Id] - - if ok { - c.bindSession(session, connection.Id) - c.connDetailsMu.RUnlock() - return loadedDetails.String(), nil - } - c.connDetailsMu.RUnlock() - c.connDetailsMu.Lock() - defer c.connDetailsMu.Unlock() - - loadedDetails, ok = c.connDetailsMap[connection.Id] - if ok { - c.bindSession(session, connection.Id) - return loadedDetails.String(), nil - } - - details, err := c.connectionProvider.GetConnectionDetails(connection.ConnectionConfig, shared.Ptr(uint32(5)), logger) - if err != nil { - return "", err - } - tunnel := details.GetTunnel() - if tunnel == nil { - c.bindSession(session, connection.Id) - c.connDetailsMap[connection.Id] = details - return details.String(), nil - } - ready, err := tunnel.Start(logger) - if err != nil { - return "", fmt.Errorf("unable to start ssh tunnel: %w", err) - } - <-ready // this isn't great as it will block all other requests until this tunnel is ready - c.connDetailsMap[connection.Id] = details - c.bindSession(session, connection.Id) - return details.String(), nil -} - -func (c *ConnectionTunnelManager[T, TConfig]) GetConnection( +func (c *ConnectionTunnelManager[T]) GetConnection( session string, connection *mgmtv1alpha1.Connection, logger *slog.Logger, @@ -124,24 +70,7 @@ func (c *ConnectionTunnelManager[T, TConfig]) GetConnection( return existingDb, nil } - connectionString, err := c.GetConnectionString(session, connection, logger) - if err != nil { - var result T - return result, err - } - driver, err := getDriverFromConnection(connection) - if err != nil { - var result T - return result, err - } - - connClientConfig, err := c.connectionProvider.GetConnectionClientConfig(connection.GetConnectionConfig()) - if err != nil { - var result T - return result, err - } - - connectionClient, err := c.connectionProvider.GetConnectionClient(driver, connectionString, connClientConfig) + connectionClient, err := c.connectionProvider.GetConnectionClient(connection.GetConnectionConfig()) if err != nil { var result T return result, err @@ -152,7 +81,7 @@ func (c *ConnectionTunnelManager[T, TConfig]) GetConnection( return connectionClient, nil } -func (c *ConnectionTunnelManager[T, TConfig]) ReleaseSession(session string) bool { +func (c *ConnectionTunnelManager[T]) ReleaseSession(session string) bool { c.sessionMu.RLock() connMap, ok := c.sessionMap[session] if !ok || len(connMap) == 0 { @@ -170,7 +99,7 @@ func (c *ConnectionTunnelManager[T, TConfig]) ReleaseSession(session string) boo return true } -func (c *ConnectionTunnelManager[T, TConfig]) bindSession(session, connectionId string) { +func (c *ConnectionTunnelManager[T]) bindSession(session, connectionId string) { c.sessionMu.RLock() connmap, ok := c.sessionMap[session] if ok { @@ -188,11 +117,11 @@ func (c *ConnectionTunnelManager[T, TConfig]) bindSession(session, connectionId c.sessionMap[session][connectionId] = struct{}{} } -func (c *ConnectionTunnelManager[T, TConfig]) Shutdown() { +func (c *ConnectionTunnelManager[T]) Shutdown() { c.shutdown <- struct{}{} } -func (c *ConnectionTunnelManager[T, TConfig]) Reaper() { +func (c *ConnectionTunnelManager[T]) Reaper() { for { select { case <-c.shutdown: @@ -204,9 +133,8 @@ func (c *ConnectionTunnelManager[T, TConfig]) Reaper() { } } -func (c *ConnectionTunnelManager[T, TConfig]) hardClose() { +func (c *ConnectionTunnelManager[T]) hardClose() { c.connMu.Lock() - c.connDetailsMu.Lock() c.sessionMu.Lock() for connId, dbConn := range c.connMap { err := c.connectionProvider.CloseClientConnection(dbConn) @@ -216,23 +144,14 @@ func (c *ConnectionTunnelManager[T, TConfig]) hardClose() { delete(c.connMap, connId) } - for connId, details := range c.connDetailsMap { - tunnel := details.GetTunnel() - if tunnel != nil { - tunnel.Close() - } - delete(c.connDetailsMap, connId) - } - for sessionId := range c.sessionMap { delete(c.sessionMap, sessionId) } c.connMu.Unlock() - c.connDetailsMu.Unlock() c.sessionMu.Unlock() } -func (c *ConnectionTunnelManager[T, TConfig]) close() { +func (c *ConnectionTunnelManager[T]) close() { c.connMu.Lock() c.sessionMu.Lock() sessionConnections := getUniqueConnectionIdsFromSessions(c.sessionMap) @@ -247,21 +166,6 @@ func (c *ConnectionTunnelManager[T, TConfig]) close() { } c.sessionMu.Unlock() c.connMu.Unlock() - - c.connDetailsMu.Lock() - c.sessionMu.Lock() - sessionConnections = getUniqueConnectionIdsFromSessions(c.sessionMap) - for connId, details := range c.connDetailsMap { - if _, ok := sessionConnections[connId]; !ok { - tunnel := details.GetTunnel() - if tunnel != nil { - tunnel.Close() - } - delete(c.connDetailsMap, connId) - } - } - c.sessionMu.Unlock() - c.connDetailsMu.Unlock() } func getUniqueConnectionIdsFromSessions(sessionMap map[string]map[string]struct{}) map[string]struct{} { @@ -273,20 +177,3 @@ func getUniqueConnectionIdsFromSessions(sessionMap map[string]map[string]struct{ } return connSet } - -func getDriverFromConnection(connection *mgmtv1alpha1.Connection) (string, error) { - if connection == nil { - return "", errors.New("connection was nil") - } - switch connection.GetConnectionConfig().Config.(type) { - case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: - return "mysql", nil - case *mgmtv1alpha1.ConnectionConfig_PgConfig: - return "pgx", nil - case *mgmtv1alpha1.ConnectionConfig_MongoConfig: - return "mongodb", nil - case *mgmtv1alpha1.ConnectionConfig_MssqlConfig: - return "sqlserver", nil - } - return "", fmt.Errorf("unsupported connection type when computing driver: %T", connection.GetConnectionConfig().Config) -} diff --git a/worker/internal/connection-tunnel-manager/manager_test.go b/worker/internal/connection-tunnel-manager/manager_test.go index d525aa1d6a..2e82e923ea 100644 --- a/worker/internal/connection-tunnel-manager/manager_test.go +++ b/worker/internal/connection-tunnel-manager/manager_test.go @@ -1,130 +1,26 @@ package connectiontunnelmanager import ( + "io" "log/slog" "testing" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - dbconnectconfig "github.com/nucleuscloud/neosync/backend/pkg/dbconnect-config" - "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" - neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" ) -func Test_NewConnectionTunnelManager(t *testing.T) { - require.NotNil(t, NewConnectionTunnelManager[any, any](nil)) -} - -func Test_ConnectionTunnelManager_GetConnectionString(t *testing.T) { - provider := NewMockConnectionProvider[any, any](t) - mgr := NewConnectionTunnelManager(provider) - - conn := &mgmtv1alpha1.Connection{ - Id: "1", - ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, - }, - } - - provider.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo"), - }, nil) - connstr, err := mgr.GetConnectionString("111", conn, slog.Default()) - require.NoError(t, err) - require.Equal(t, "postgres://foo:bar@localhost:5432/test", connstr) -} - -func Test_ConnectionTunnelManager_GetConnectionString_Unique_Conns(t *testing.T) { - provider := NewMockConnectionProvider[any, any](t) - mgr := NewConnectionTunnelManager(provider) - - conn1 := &mgmtv1alpha1.Connection{ - Id: "1", - ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ - PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ - Connection: &mgmtv1alpha1.PostgresConnection{ - Host: "1", - }, - }, - }, - }, - }, - } - conn2 := &mgmtv1alpha1.Connection{ - Id: "2", - ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ - PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ - Connection: &mgmtv1alpha1.PostgresConnection{ - Host: "2", - }, - }, - }, - }, - }, - } - - provider.On("GetConnectionDetails", conn1.ConnectionConfig, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo"), - }, nil) - - provider.On("GetConnectionDetails", conn2.ConnectionConfig, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo2"), - }, nil) - - connstr, err := mgr.GetConnectionString("111", conn1, slog.Default()) - require.NoError(t, err) - require.Equal(t, "postgres://foo:bar@localhost:5432/test", connstr) - connstr, err = mgr.GetConnectionString("111", conn2, slog.Default()) - require.NoError(t, err) - require.Equal(t, "postgres://foo2:bar@localhost:5432/test", connstr) -} - -func Test_ConnectionTunnelManager_GetConnectionString_Parallel_Sessions_Same_Connection(t *testing.T) { - provider := NewMockConnectionProvider[any, any](t) - mgr := NewConnectionTunnelManager(provider) - - conn := &mgmtv1alpha1.Connection{ - Id: "1", - ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, - }, - } - - provider.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo"), - }, nil) - - errgrp := errgroup.Group{} - errgrp.Go(func() error { - _, err := mgr.GetConnectionString("111", conn, slog.Default()) - return err - }) - errgrp.Go(func() error { - _, err := mgr.GetConnectionString("222", conn, slog.Default()) - return err - }) - errgrp.Go(func() error { - _, err := mgr.GetConnectionString("333", conn, slog.Default()) - return err - }) - err := errgrp.Wait() - require.NoError(t, err) +var ( + discardLogger = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) +) - provider.AssertNumberOfCalls(t, "GetConnectionDetails", 1) +func Test_NewConnectionTunnelManager(t *testing.T) { + require.NotNil(t, NewConnectionTunnelManager[any](nil)) } func Test_ConnectionTunnelManager_GetConnectionClient(t *testing.T) { - provider := NewMockConnectionProvider[any, any](t) + provider := NewMockConnectionProvider[any](t) mgr := NewConnectionTunnelManager(provider) conn := &mgmtv1alpha1.Connection{ @@ -134,57 +30,47 @@ func Test_ConnectionTunnelManager_GetConnectionClient(t *testing.T) { }, } - provider.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo"), - }, nil) - provider.On("GetConnectionClientConfig", mock.Anything).Return(struct{}{}, nil) - provider.On("GetConnectionClient", "pgx", "postgres://foo:bar@localhost:5432/test", mock.Anything).Return(neosync_benthos_sql.NewMockSqlDbtx(t), nil) + provider.On("GetConnectionClient", mock.Anything).Return(&struct{}{}, nil) - db, err := mgr.GetConnection("111", conn, slog.Default()) + db, err := mgr.GetConnection("111", conn, discardLogger) require.NoError(t, err) require.NotNil(t, db) } func Test_ConnectionTunnelManager_GetConnection_Parallel_Sessions_Same_Connection(t *testing.T) { - provider := NewMockConnectionProvider[any, any](t) + provider := NewMockConnectionProvider[any](t) mgr := NewConnectionTunnelManager(provider) + cc := &mgmtv1alpha1.ConnectionConfig{ + Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, + } conn := &mgmtv1alpha1.Connection{ - Id: "1", - ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, - }, + Id: "1", + ConnectionConfig: cc, } - provider.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo"), - }, nil) - provider.On("GetConnectionClientConfig", mock.Anything).Return(struct{}{}, nil) - provider.On("GetConnectionClient", "pgx", "postgres://foo:bar@localhost:5432/test", mock.Anything).Return(neosync_benthos_sql.NewMockSqlDbtx(t), nil) + provider.On("GetConnectionClient", cc).Return(&struct{}{}, nil) errgrp := errgroup.Group{} errgrp.Go(func() error { - _, err := mgr.GetConnection("111", conn, slog.Default()) + _, err := mgr.GetConnection("111", conn, discardLogger) return err }) errgrp.Go(func() error { - _, err := mgr.GetConnection("222", conn, slog.Default()) + _, err := mgr.GetConnection("222", conn, discardLogger) return err }) errgrp.Go(func() error { - _, err := mgr.GetConnection("333", conn, slog.Default()) + _, err := mgr.GetConnection("333", conn, discardLogger) return err }) err := errgrp.Wait() require.NoError(t, err) - provider.AssertNumberOfCalls(t, "GetConnectionDetails", 1) provider.AssertNumberOfCalls(t, "GetConnectionClient", 1) } func Test_ConnectionTunnelManager_ReleaseSession(t *testing.T) { - provider := NewMockConnectionProvider[any, any](t) + provider := NewMockConnectionProvider[any](t) mgr := NewConnectionTunnelManager(provider) require.False(t, mgr.ReleaseSession("111"), "currently no session") @@ -196,18 +82,16 @@ func Test_ConnectionTunnelManager_ReleaseSession(t *testing.T) { }, } - provider.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo"), - }, nil) - _, err := mgr.GetConnectionString("111", conn, slog.Default()) + provider.On("GetConnectionClient", mock.Anything). + Return(&struct{}{}, nil) + _, err := mgr.GetConnection("111", conn, discardLogger) require.NoError(t, err) require.True(t, mgr.ReleaseSession("111"), "released an existing session") } func Test_ConnectionTunnelManager_close(t *testing.T) { - provider := NewMockConnectionProvider[any, any](t) + provider := NewMockConnectionProvider[any](t) mgr := NewConnectionTunnelManager(provider) require.False(t, mgr.ReleaseSession("111"), "currently no session") @@ -219,31 +103,24 @@ func Test_ConnectionTunnelManager_close(t *testing.T) { }, } - provider.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo"), - }, nil) - mockDb := neosync_benthos_sql.NewMockSqlDbtx(t) - provider.On("GetConnectionClientConfig", mock.Anything).Return(struct{}{}, nil) - provider.On("GetConnectionClient", "pgx", "postgres://foo:bar@localhost:5432/test", mock.Anything).Return(mockDb, nil) + mockDb := &struct{}{} + + provider.On("GetConnectionClient", mock.Anything).Return(mockDb, nil) provider.On("CloseClientConnection", mockDb).Return(nil) - _, err := mgr.GetConnection("111", conn, slog.Default()) + _, err := mgr.GetConnection("111", conn, discardLogger) require.NoError(t, err) - require.NotEmpty(t, mgr.connDetailsMap, "has an active connection") require.NotEmpty(t, mgr.connMap, "has an active connection") mgr.close() - require.NotEmpty(t, mgr.connDetailsMap, "not empty due to active session") require.NotEmpty(t, mgr.connMap, "not empty due to active session") require.True(t, mgr.ReleaseSession("111"), "released an existing session") mgr.close() - require.Empty(t, mgr.connDetailsMap, "now empty due to no active sessions") require.Empty(t, mgr.connMap, "now empty due to no active sessions") } func Test_ConnectionTunnelManager_hardClose(t *testing.T) { - provider := NewMockConnectionProvider[any, any](t) + provider := NewMockConnectionProvider[any](t) mgr := NewConnectionTunnelManager(provider) require.False(t, mgr.ReleaseSession("111"), "currently no session") @@ -255,22 +132,15 @@ func Test_ConnectionTunnelManager_hardClose(t *testing.T) { }, } - provider.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{ - GeneralDbConnectConfig: getPgGenDbConfig(t, "foo"), - }, nil) - mockDb := neosync_benthos_sql.NewMockSqlDbtx(t) - provider.On("GetConnectionClientConfig", mock.Anything).Return(struct{}{}, nil) - provider.On("GetConnectionClient", "pgx", "postgres://foo:bar@localhost:5432/test", mock.Anything).Return(mockDb, nil) + mockDb := struct{}{} + provider.On("GetConnectionClient", mock.Anything).Return(mockDb, nil) provider.On("CloseClientConnection", mockDb).Return(nil) - _, err := mgr.GetConnection("111", conn, slog.Default()) + _, err := mgr.GetConnection("111", conn, discardLogger) require.NoError(t, err) - require.NotEmpty(t, mgr.connDetailsMap, "has an active connection") require.NotEmpty(t, mgr.connMap, "has an active connection") mgr.hardClose() - require.Empty(t, mgr.connDetailsMap, "now empty due to no active sessions") require.Empty(t, mgr.connMap, "now empty due to no active sessions") } @@ -292,73 +162,3 @@ func Test_getUniqueConnectionIdsFromSessions(t *testing.T) { require.Contains(t, output, "222") require.Contains(t, output, "333") } - -func Test_getDriverFromConnection(t *testing.T) { - t.Run("nil", func(t *testing.T) { - driver, err := getDriverFromConnection(nil) - require.Error(t, err) - require.Empty(t, driver) - }) - - t.Run("postgres", func(t *testing.T) { - driver, err := getDriverFromConnection(&mgmtv1alpha1.Connection{ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, - }}) - require.NoError(t, err) - require.Equal(t, "pgx", driver) - }) - - t.Run("mysql", func(t *testing.T) { - driver, err := getDriverFromConnection(&mgmtv1alpha1.Connection{ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{}, - }}) - require.NoError(t, err) - require.Equal(t, "mysql", driver) - }) - - t.Run("mssql", func(t *testing.T) { - driver, err := getDriverFromConnection(&mgmtv1alpha1.Connection{ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MssqlConfig{}, - }}) - require.NoError(t, err) - require.Equal(t, "sqlserver", driver) - }) - - t.Run("unsupported", func(t *testing.T) { - driver, err := getDriverFromConnection(&mgmtv1alpha1.Connection{ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_AwsS3Config{}, - }}) - require.Error(t, err) - require.Empty(t, driver) - }) - - t.Run("mongo", func(t *testing.T) { - driver, err := getDriverFromConnection(&mgmtv1alpha1.Connection{ConnectionConfig: &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MongoConfig{}, - }}) - require.NoError(t, err) - require.Equal(t, "mongodb", driver) - }) -} - -func getPgGenDbConfig(t *testing.T, user string) dbconnectconfig.GeneralDbConnectConfig { - t.Helper() - dbcc, err := dbconnectconfig.NewFromPostgresConnection( - &mgmtv1alpha1.ConnectionConfig_PgConfig{ - PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ - ConnectionConfig: &mgmtv1alpha1.PostgresConnectionConfig_Connection{ - Connection: &mgmtv1alpha1.PostgresConnection{ - Host: "localhost", - Port: int32(5432), - Name: "test", - User: user, - Pass: "bar", - }, - }, - }, - }, - nil, - ) - require.NoError(t, err) - return *dbcc -} diff --git a/worker/internal/connection-tunnel-manager/mock_ConnectionProvider.go b/worker/internal/connection-tunnel-manager/mock_ConnectionProvider.go index 4269560054..3b6c149845 100644 --- a/worker/internal/connection-tunnel-manager/mock_ConnectionProvider.go +++ b/worker/internal/connection-tunnel-manager/mock_ConnectionProvider.go @@ -5,25 +5,23 @@ package connectiontunnelmanager import ( mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" mock "github.com/stretchr/testify/mock" - - slog "log/slog" ) // MockConnectionProvider is an autogenerated mock type for the ConnectionProvider type -type MockConnectionProvider[T interface{}, TConfig interface{}] struct { +type MockConnectionProvider[T any] struct { mock.Mock } -type MockConnectionProvider_Expecter[T interface{}, TConfig interface{}] struct { +type MockConnectionProvider_Expecter[T any] struct { mock *mock.Mock } -func (_m *MockConnectionProvider[T, TConfig]) EXPECT() *MockConnectionProvider_Expecter[T, TConfig] { - return &MockConnectionProvider_Expecter[T, TConfig]{mock: &_m.Mock} +func (_m *MockConnectionProvider[T]) EXPECT() *MockConnectionProvider_Expecter[T] { + return &MockConnectionProvider_Expecter[T]{mock: &_m.Mock} } // CloseClientConnection provides a mock function with given fields: client -func (_m *MockConnectionProvider[T, TConfig]) CloseClientConnection(client T) error { +func (_m *MockConnectionProvider[T]) CloseClientConnection(client T) error { ret := _m.Called(client) if len(ret) == 0 { @@ -41,36 +39,36 @@ func (_m *MockConnectionProvider[T, TConfig]) CloseClientConnection(client T) er } // MockConnectionProvider_CloseClientConnection_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseClientConnection' -type MockConnectionProvider_CloseClientConnection_Call[T interface{}, TConfig interface{}] struct { +type MockConnectionProvider_CloseClientConnection_Call[T any] struct { *mock.Call } // CloseClientConnection is a helper method to define mock.On call // - client T -func (_e *MockConnectionProvider_Expecter[T, TConfig]) CloseClientConnection(client interface{}) *MockConnectionProvider_CloseClientConnection_Call[T, TConfig] { - return &MockConnectionProvider_CloseClientConnection_Call[T, TConfig]{Call: _e.mock.On("CloseClientConnection", client)} +func (_e *MockConnectionProvider_Expecter[T]) CloseClientConnection(client interface{}) *MockConnectionProvider_CloseClientConnection_Call[T] { + return &MockConnectionProvider_CloseClientConnection_Call[T]{Call: _e.mock.On("CloseClientConnection", client)} } -func (_c *MockConnectionProvider_CloseClientConnection_Call[T, TConfig]) Run(run func(client T)) *MockConnectionProvider_CloseClientConnection_Call[T, TConfig] { +func (_c *MockConnectionProvider_CloseClientConnection_Call[T]) Run(run func(client T)) *MockConnectionProvider_CloseClientConnection_Call[T] { _c.Call.Run(func(args mock.Arguments) { run(args[0].(T)) }) return _c } -func (_c *MockConnectionProvider_CloseClientConnection_Call[T, TConfig]) Return(_a0 error) *MockConnectionProvider_CloseClientConnection_Call[T, TConfig] { +func (_c *MockConnectionProvider_CloseClientConnection_Call[T]) Return(_a0 error) *MockConnectionProvider_CloseClientConnection_Call[T] { _c.Call.Return(_a0) return _c } -func (_c *MockConnectionProvider_CloseClientConnection_Call[T, TConfig]) RunAndReturn(run func(T) error) *MockConnectionProvider_CloseClientConnection_Call[T, TConfig] { +func (_c *MockConnectionProvider_CloseClientConnection_Call[T]) RunAndReturn(run func(T) error) *MockConnectionProvider_CloseClientConnection_Call[T] { _c.Call.Return(run) return _c } -// GetConnectionClient provides a mock function with given fields: driver, connectionString, opts -func (_m *MockConnectionProvider[T, TConfig]) GetConnectionClient(driver string, connectionString string, opts TConfig) (T, error) { - ret := _m.Called(driver, connectionString, opts) +// GetConnectionClient provides a mock function with given fields: connectionConfig +func (_m *MockConnectionProvider[T]) GetConnectionClient(connectionConfig *mgmtv1alpha1.ConnectionConfig) (T, error) { + ret := _m.Called(connectionConfig) if len(ret) == 0 { panic("no return value specified for GetConnectionClient") @@ -78,71 +76,13 @@ func (_m *MockConnectionProvider[T, TConfig]) GetConnectionClient(driver string, var r0 T var r1 error - if rf, ok := ret.Get(0).(func(string, string, TConfig) (T, error)); ok { - return rf(driver, connectionString, opts) - } - if rf, ok := ret.Get(0).(func(string, string, TConfig) T); ok { - r0 = rf(driver, connectionString, opts) - } else { - r0 = ret.Get(0).(T) - } - - if rf, ok := ret.Get(1).(func(string, string, TConfig) error); ok { - r1 = rf(driver, connectionString, opts) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockConnectionProvider_GetConnectionClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetConnectionClient' -type MockConnectionProvider_GetConnectionClient_Call[T interface{}, TConfig interface{}] struct { - *mock.Call -} - -// GetConnectionClient is a helper method to define mock.On call -// - driver string -// - connectionString string -// - opts TConfig -func (_e *MockConnectionProvider_Expecter[T, TConfig]) GetConnectionClient(driver interface{}, connectionString interface{}, opts interface{}) *MockConnectionProvider_GetConnectionClient_Call[T, TConfig] { - return &MockConnectionProvider_GetConnectionClient_Call[T, TConfig]{Call: _e.mock.On("GetConnectionClient", driver, connectionString, opts)} -} - -func (_c *MockConnectionProvider_GetConnectionClient_Call[T, TConfig]) Run(run func(driver string, connectionString string, opts TConfig)) *MockConnectionProvider_GetConnectionClient_Call[T, TConfig] { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(string), args[1].(string), args[2].(TConfig)) - }) - return _c -} - -func (_c *MockConnectionProvider_GetConnectionClient_Call[T, TConfig]) Return(_a0 T, _a1 error) *MockConnectionProvider_GetConnectionClient_Call[T, TConfig] { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockConnectionProvider_GetConnectionClient_Call[T, TConfig]) RunAndReturn(run func(string, string, TConfig) (T, error)) *MockConnectionProvider_GetConnectionClient_Call[T, TConfig] { - _c.Call.Return(run) - return _c -} - -// GetConnectionClientConfig provides a mock function with given fields: connectionConfig -func (_m *MockConnectionProvider[T, TConfig]) GetConnectionClientConfig(connectionConfig *mgmtv1alpha1.ConnectionConfig) (TConfig, error) { - ret := _m.Called(connectionConfig) - - if len(ret) == 0 { - panic("no return value specified for GetConnectionClientConfig") - } - - var r0 TConfig - var r1 error - if rf, ok := ret.Get(0).(func(*mgmtv1alpha1.ConnectionConfig) (TConfig, error)); ok { + if rf, ok := ret.Get(0).(func(*mgmtv1alpha1.ConnectionConfig) (T, error)); ok { return rf(connectionConfig) } - if rf, ok := ret.Get(0).(func(*mgmtv1alpha1.ConnectionConfig) TConfig); ok { + if rf, ok := ret.Get(0).(func(*mgmtv1alpha1.ConnectionConfig) T); ok { r0 = rf(connectionConfig) } else { - r0 = ret.Get(0).(TConfig) + r0 = ret.Get(0).(T) } if rf, ok := ret.Get(1).(func(*mgmtv1alpha1.ConnectionConfig) error); ok { @@ -154,101 +94,41 @@ func (_m *MockConnectionProvider[T, TConfig]) GetConnectionClientConfig(connecti return r0, r1 } -// MockConnectionProvider_GetConnectionClientConfig_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetConnectionClientConfig' -type MockConnectionProvider_GetConnectionClientConfig_Call[T interface{}, TConfig interface{}] struct { +// MockConnectionProvider_GetConnectionClient_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetConnectionClient' +type MockConnectionProvider_GetConnectionClient_Call[T any] struct { *mock.Call } -// GetConnectionClientConfig is a helper method to define mock.On call +// GetConnectionClient is a helper method to define mock.On call // - connectionConfig *mgmtv1alpha1.ConnectionConfig -func (_e *MockConnectionProvider_Expecter[T, TConfig]) GetConnectionClientConfig(connectionConfig interface{}) *MockConnectionProvider_GetConnectionClientConfig_Call[T, TConfig] { - return &MockConnectionProvider_GetConnectionClientConfig_Call[T, TConfig]{Call: _e.mock.On("GetConnectionClientConfig", connectionConfig)} +func (_e *MockConnectionProvider_Expecter[T]) GetConnectionClient(connectionConfig interface{}) *MockConnectionProvider_GetConnectionClient_Call[T] { + return &MockConnectionProvider_GetConnectionClient_Call[T]{Call: _e.mock.On("GetConnectionClient", connectionConfig)} } -func (_c *MockConnectionProvider_GetConnectionClientConfig_Call[T, TConfig]) Run(run func(connectionConfig *mgmtv1alpha1.ConnectionConfig)) *MockConnectionProvider_GetConnectionClientConfig_Call[T, TConfig] { +func (_c *MockConnectionProvider_GetConnectionClient_Call[T]) Run(run func(connectionConfig *mgmtv1alpha1.ConnectionConfig)) *MockConnectionProvider_GetConnectionClient_Call[T] { _c.Call.Run(func(args mock.Arguments) { run(args[0].(*mgmtv1alpha1.ConnectionConfig)) }) return _c } -func (_c *MockConnectionProvider_GetConnectionClientConfig_Call[T, TConfig]) Return(_a0 TConfig, _a1 error) *MockConnectionProvider_GetConnectionClientConfig_Call[T, TConfig] { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockConnectionProvider_GetConnectionClientConfig_Call[T, TConfig]) RunAndReturn(run func(*mgmtv1alpha1.ConnectionConfig) (TConfig, error)) *MockConnectionProvider_GetConnectionClientConfig_Call[T, TConfig] { - _c.Call.Return(run) - return _c -} - -// GetConnectionDetails provides a mock function with given fields: connectionConfig, connectionTimeout, logger -func (_m *MockConnectionProvider[T, TConfig]) GetConnectionDetails(connectionConfig *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger) (ConnectionDetails, error) { - ret := _m.Called(connectionConfig, connectionTimeout, logger) - - if len(ret) == 0 { - panic("no return value specified for GetConnectionDetails") - } - - var r0 ConnectionDetails - var r1 error - if rf, ok := ret.Get(0).(func(*mgmtv1alpha1.ConnectionConfig, *uint32, *slog.Logger) (ConnectionDetails, error)); ok { - return rf(connectionConfig, connectionTimeout, logger) - } - if rf, ok := ret.Get(0).(func(*mgmtv1alpha1.ConnectionConfig, *uint32, *slog.Logger) ConnectionDetails); ok { - r0 = rf(connectionConfig, connectionTimeout, logger) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(ConnectionDetails) - } - } - - if rf, ok := ret.Get(1).(func(*mgmtv1alpha1.ConnectionConfig, *uint32, *slog.Logger) error); ok { - r1 = rf(connectionConfig, connectionTimeout, logger) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockConnectionProvider_GetConnectionDetails_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetConnectionDetails' -type MockConnectionProvider_GetConnectionDetails_Call[T interface{}, TConfig interface{}] struct { - *mock.Call -} - -// GetConnectionDetails is a helper method to define mock.On call -// - connectionConfig *mgmtv1alpha1.ConnectionConfig -// - connectionTimeout *uint32 -// - logger *slog.Logger -func (_e *MockConnectionProvider_Expecter[T, TConfig]) GetConnectionDetails(connectionConfig interface{}, connectionTimeout interface{}, logger interface{}) *MockConnectionProvider_GetConnectionDetails_Call[T, TConfig] { - return &MockConnectionProvider_GetConnectionDetails_Call[T, TConfig]{Call: _e.mock.On("GetConnectionDetails", connectionConfig, connectionTimeout, logger)} -} - -func (_c *MockConnectionProvider_GetConnectionDetails_Call[T, TConfig]) Run(run func(connectionConfig *mgmtv1alpha1.ConnectionConfig, connectionTimeout *uint32, logger *slog.Logger)) *MockConnectionProvider_GetConnectionDetails_Call[T, TConfig] { - _c.Call.Run(func(args mock.Arguments) { - run(args[0].(*mgmtv1alpha1.ConnectionConfig), args[1].(*uint32), args[2].(*slog.Logger)) - }) - return _c -} - -func (_c *MockConnectionProvider_GetConnectionDetails_Call[T, TConfig]) Return(_a0 ConnectionDetails, _a1 error) *MockConnectionProvider_GetConnectionDetails_Call[T, TConfig] { +func (_c *MockConnectionProvider_GetConnectionClient_Call[T]) Return(_a0 T, _a1 error) *MockConnectionProvider_GetConnectionClient_Call[T] { _c.Call.Return(_a0, _a1) return _c } -func (_c *MockConnectionProvider_GetConnectionDetails_Call[T, TConfig]) RunAndReturn(run func(*mgmtv1alpha1.ConnectionConfig, *uint32, *slog.Logger) (ConnectionDetails, error)) *MockConnectionProvider_GetConnectionDetails_Call[T, TConfig] { +func (_c *MockConnectionProvider_GetConnectionClient_Call[T]) RunAndReturn(run func(*mgmtv1alpha1.ConnectionConfig) (T, error)) *MockConnectionProvider_GetConnectionClient_Call[T] { _c.Call.Return(run) return _c } // NewMockConnectionProvider creates a new instance of MockConnectionProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewMockConnectionProvider[T interface{}, TConfig interface{}](t interface { +func NewMockConnectionProvider[T any](t interface { mock.TestingT Cleanup(func()) -}) *MockConnectionProvider[T, TConfig] { - mock := &MockConnectionProvider[T, TConfig]{} +}) *MockConnectionProvider[T] { + mock := &MockConnectionProvider[T]{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) diff --git a/worker/internal/connection-tunnel-manager/providers/mongoprovider/provider.go b/worker/internal/connection-tunnel-manager/providers/mongoprovider/provider.go index 1687911f6b..162ad41b91 100644 --- a/worker/internal/connection-tunnel-manager/providers/mongoprovider/provider.go +++ b/worker/internal/connection-tunnel-manager/providers/mongoprovider/provider.go @@ -2,11 +2,9 @@ package mongoprovider import ( "context" - "log/slog" + "errors" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - "github.com/nucleuscloud/neosync/backend/pkg/clienttls" - "github.com/nucleuscloud/neosync/backend/pkg/mongoconnect" connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager" neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb" "go.mongodb.org/mongo-driver/mongo" @@ -19,20 +17,17 @@ func NewProvider() *Provider { return &Provider{} } -var _ connectiontunnelmanager.ConnectionProvider[neosync_benthos_mongodb.MongoClient, any] = &Provider{} - -func (p *Provider) GetConnectionDetails( - cc *mgmtv1alpha1.ConnectionConfig, - connectionTimeout *uint32, - logger *slog.Logger, -) (connectiontunnelmanager.ConnectionDetails, error) { - return mongoconnect.GetConnectionDetails(cc, clienttls.UpsertCLientTlsFiles, logger) -} +var _ connectiontunnelmanager.ConnectionProvider[neosync_benthos_mongodb.MongoClient] = &Provider{} // this is currently untested as it isn't really used anywhere -func (p *Provider) GetConnectionClient(driver, connectionString string, opts any) (neosync_benthos_mongodb.MongoClient, error) { +func (p *Provider) GetConnectionClient(cc *mgmtv1alpha1.ConnectionConfig) (neosync_benthos_mongodb.MongoClient, error) { + connStr := cc.GetMongoConfig().GetUrl() + if connStr == "" { + return nil, errors.New("unable to find mongodb url on connection config") + } + serverAPI := options.ServerAPI(options.ServerAPIVersion1) - opts2 := options.Client().ApplyURI(connectionString).SetServerAPIOptions(serverAPI) + opts2 := options.Client().ApplyURI(connStr).SetServerAPIOptions(serverAPI) client, err := mongo.Connect(context.Background(), opts2) if err != nil { @@ -44,8 +39,3 @@ func (p *Provider) GetConnectionClient(driver, connectionString string, opts any func (p *Provider) CloseClientConnection(client neosync_benthos_mongodb.MongoClient) error { return client.Disconnect(context.Background()) } - -func (p *Provider) GetConnectionClientConfig(cc *mgmtv1alpha1.ConnectionConfig) (any, error) { - var result any - return result, nil -} diff --git a/worker/internal/connection-tunnel-manager/providers/provider.go b/worker/internal/connection-tunnel-manager/providers/provider.go index d78cd1d32d..e8c647e39a 100644 --- a/worker/internal/connection-tunnel-manager/providers/provider.go +++ b/worker/internal/connection-tunnel-manager/providers/provider.go @@ -2,25 +2,23 @@ package providers import ( "fmt" - "log/slog" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager" - "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager/providers/sqlprovider" neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb" neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql" ) type Provider struct { - mp connectiontunnelmanager.ConnectionProvider[neosync_benthos_mongodb.MongoClient, any] - sp connectiontunnelmanager.ConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig] + mp connectiontunnelmanager.ConnectionProvider[neosync_benthos_mongodb.MongoClient] + sp connectiontunnelmanager.ConnectionProvider[neosync_benthos_sql.SqlDbtx] } -var _ connectiontunnelmanager.ConnectionProvider[any, any] = &Provider{} +var _ connectiontunnelmanager.ConnectionProvider[any] = &Provider{} func NewProvider( - mp connectiontunnelmanager.ConnectionProvider[neosync_benthos_mongodb.MongoClient, any], - sp connectiontunnelmanager.ConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig], + mp connectiontunnelmanager.ConnectionProvider[neosync_benthos_mongodb.MongoClient], + sp connectiontunnelmanager.ConnectionProvider[neosync_benthos_sql.SqlDbtx], ) *Provider { return &Provider{ mp: mp, @@ -28,38 +26,20 @@ func NewProvider( } } -func (p *Provider) GetConnectionDetails( - cc *mgmtv1alpha1.ConnectionConfig, - connectionTimeout *uint32, - logger *slog.Logger, -) (connectiontunnelmanager.ConnectionDetails, error) { +func (p *Provider) GetConnectionClient(cc *mgmtv1alpha1.ConnectionConfig) (any, error) { if cc == nil { cc = &mgmtv1alpha1.ConnectionConfig{} } switch cc.GetConfig().(type) { case *mgmtv1alpha1.ConnectionConfig_MongoConfig: - return p.mp.GetConnectionDetails(cc, connectionTimeout, logger) + return p.mp.GetConnectionClient(cc) case *mgmtv1alpha1.ConnectionConfig_MysqlConfig, *mgmtv1alpha1.ConnectionConfig_PgConfig, *mgmtv1alpha1.ConnectionConfig_MssqlConfig: - return p.sp.GetConnectionDetails(cc, connectionTimeout, logger) + return p.sp.GetConnectionClient(cc) default: return nil, fmt.Errorf("unsupported connection config: %T", cc.GetConfig()) } } -func (p *Provider) GetConnectionClient(driver, connectionString string, opts any) (any, error) { - switch driver { - case "mysql", "postgres", "postgresql", "sqlserver", "pgx": - typedopts, ok := opts.(*sqlprovider.ConnectionClientConfig) - if !ok { - return nil, fmt.Errorf("opts was not *sqlprovider.ConnectionClientConfig, was %T", opts) - } - return p.sp.GetConnectionClient(driver, connectionString, typedopts) - case "mongodb", "mongodb+srv": - return p.mp.GetConnectionClient(driver, connectionString, opts) - } - return nil, fmt.Errorf("unsupported driver: %s", driver) -} - func (p *Provider) CloseClientConnection(client any) error { switch typedclient := client.(type) { case neosync_benthos_sql.SqlDbtx: @@ -70,14 +50,3 @@ func (p *Provider) CloseClientConnection(client any) error { return fmt.Errorf("unsupported client, unable to close connection: %T", client) } } - -func (p *Provider) GetConnectionClientConfig(cc *mgmtv1alpha1.ConnectionConfig) (any, error) { - switch cc.GetConfig().(type) { - case *mgmtv1alpha1.ConnectionConfig_MongoConfig: - return p.mp.GetConnectionClientConfig(cc) - case *mgmtv1alpha1.ConnectionConfig_MysqlConfig, *mgmtv1alpha1.ConnectionConfig_PgConfig, *mgmtv1alpha1.ConnectionConfig_MssqlConfig: - return p.sp.GetConnectionClientConfig(cc) - default: - return nil, fmt.Errorf("unsupported connection config: %T", cc.GetConfig()) - } -} diff --git a/worker/internal/connection-tunnel-manager/providers/provider_test.go b/worker/internal/connection-tunnel-manager/providers/provider_test.go index 8266d561ba..cc8acf6489 100644 --- a/worker/internal/connection-tunnel-manager/providers/provider_test.go +++ b/worker/internal/connection-tunnel-manager/providers/provider_test.go @@ -4,10 +4,7 @@ import ( "testing" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - "github.com/nucleuscloud/neosync/backend/pkg/mongoconnect" - "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager" - "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager/providers/sqlprovider" neosync_benthos_mongodb "github.com/nucleuscloud/neosync/worker/pkg/benthos/mongodb" neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql" "github.com/stretchr/testify/mock" @@ -16,103 +13,34 @@ import ( ) func Test_NewProvider(t *testing.T) { - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) + mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient](t) + mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx](t) require.NotNil(t, NewProvider(mockMp, mockSp)) } -func Test_Provider_GetConnectionDetails(t *testing.T) { - t.Run("mongo", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - mockMp.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&mongoconnect.ConnectionDetails{}, nil) - - result, err := provider.GetConnectionDetails(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MongoConfig{}, - }, nil, nil) - require.NoError(t, err) - require.NotNil(t, result) - }) - - t.Run("postgres", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - mockSp.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{}, nil) - - result, err := provider.GetConnectionDetails(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, - }, nil, nil) - require.NoError(t, err) - require.NotNil(t, result) - }) - - t.Run("mysql", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - mockSp.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{}, nil) - - result, err := provider.GetConnectionDetails(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{}, - }, nil, nil) - require.NoError(t, err) - require.NotNil(t, result) - }) - - t.Run("mssql", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - mockSp.On("GetConnectionDetails", mock.Anything, mock.Anything, mock.Anything). - Return(&sqlconnect.ConnectionDetails{}, nil) - - result, err := provider.GetConnectionDetails(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MssqlConfig{}, - }, nil, nil) - require.NoError(t, err) - require.NotNil(t, result) - }) -} - func Test_Provider_GetConnectionClient(t *testing.T) { t.Run("mongo", func(t *testing.T) { t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) + mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient](t) + mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx](t) provider := NewProvider(mockMp, mockSp) mockMp.On("GetConnectionClient", mock.Anything, mock.Anything, mock.Anything). Return(&mongo.Client{}, nil) - var opts any = struct{}{} - result, err := provider.GetConnectionClient("mongodb", "test-str", opts) + result, err := provider.GetConnectionClient(&mgmtv1alpha1.ConnectionConfig{ + Config: &mgmtv1alpha1.ConnectionConfig_MongoConfig{}, + }) require.NoError(t, err) require.NotNil(t, result) }) t.Run("postgres", func(t *testing.T) { t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) + mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient](t) + mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx](t) mockDbtx := neosync_benthos_sql.NewMockSqlDbtx(t) provider := NewProvider(mockMp, mockSp) @@ -120,29 +48,17 @@ func Test_Provider_GetConnectionClient(t *testing.T) { mockSp.On("GetConnectionClient", mock.Anything, mock.Anything, mock.Anything). Return(mockDbtx, nil) - opts := &sqlprovider.ConnectionClientConfig{} - result, err := provider.GetConnectionClient("postgres", "test-str", opts) + result, err := provider.GetConnectionClient(&mgmtv1alpha1.ConnectionConfig{ + Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, + }) require.NoError(t, err) require.NotNil(t, result) }) - t.Run("postgres-bad", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - var opts any = struct{}{} - result, err := provider.GetConnectionClient("postgres", "test-str", opts) - require.Error(t, err) - require.Nil(t, result) - }) - t.Run("mysql", func(t *testing.T) { t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) + mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient](t) + mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx](t) mockDbtx := neosync_benthos_sql.NewMockSqlDbtx(t) provider := NewProvider(mockMp, mockSp) @@ -150,29 +66,17 @@ func Test_Provider_GetConnectionClient(t *testing.T) { mockSp.On("GetConnectionClient", mock.Anything, mock.Anything, mock.Anything). Return(mockDbtx, nil) - opts := &sqlprovider.ConnectionClientConfig{} - result, err := provider.GetConnectionClient("mysql", "test-str", opts) + result, err := provider.GetConnectionClient(&mgmtv1alpha1.ConnectionConfig{ + Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{}, + }) require.NoError(t, err) require.NotNil(t, result) }) - t.Run("mysql-bad", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - var opts any = struct{}{} - result, err := provider.GetConnectionClient("mysql", "test-str", opts) - require.Error(t, err) - require.Nil(t, result) - }) - t.Run("mssql", func(t *testing.T) { t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) + mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient](t) + mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx](t) mockDbtx := neosync_benthos_sql.NewMockSqlDbtx(t) provider := NewProvider(mockMp, mockSp) @@ -180,31 +84,19 @@ func Test_Provider_GetConnectionClient(t *testing.T) { mockSp.On("GetConnectionClient", mock.Anything, mock.Anything, mock.Anything). Return(mockDbtx, nil) - opts := &sqlprovider.ConnectionClientConfig{} - result, err := provider.GetConnectionClient("sqlserver", "test-str", opts) + result, err := provider.GetConnectionClient(&mgmtv1alpha1.ConnectionConfig{ + Config: &mgmtv1alpha1.ConnectionConfig_MssqlConfig{}, + }) require.NoError(t, err) require.NotNil(t, result) }) - - t.Run("mssql-bad", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - var opts any = struct{}{} - result, err := provider.GetConnectionClient("sqlserver", "test-str", opts) - require.Error(t, err) - require.Nil(t, result) - }) } func Test_Provider_CloseClientConnection(t *testing.T) { t.Run("mongo", func(t *testing.T) { t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) + mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient](t) + mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx](t) provider := NewProvider(mockMp, mockSp) @@ -217,8 +109,8 @@ func Test_Provider_CloseClientConnection(t *testing.T) { t.Run("sql", func(t *testing.T) { t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) + mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient](t) + mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx](t) mockDbtx := neosync_benthos_sql.NewMockSqlDbtx(t) provider := NewProvider(mockMp, mockSp) @@ -229,81 +121,3 @@ func Test_Provider_CloseClientConnection(t *testing.T) { require.NoError(t, err) }) } - -func Test_Provider_GetConnectionClientConfig(t *testing.T) { - t.Run("mongo", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - cc := &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MongoConfig{}, - } - - var result any = struct{}{} - mockMp.On("GetConnectionClientConfig", cc).Return(result, nil) - - config, err := provider.GetConnectionClientConfig(cc) - require.NoError(t, err) - require.Equal(t, result, config) - }) - - t.Run("postgres", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - cc := &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, - } - - result := &sqlprovider.ConnectionClientConfig{} - mockSp.On("GetConnectionClientConfig", cc).Return(result, nil) - - config, err := provider.GetConnectionClientConfig(cc) - require.NoError(t, err) - require.Equal(t, result, config) - }) - - t.Run("mysql", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - cc := &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{}, - } - - result := &sqlprovider.ConnectionClientConfig{} - mockSp.On("GetConnectionClientConfig", cc).Return(result, nil) - - config, err := provider.GetConnectionClientConfig(cc) - require.NoError(t, err) - require.Equal(t, result, config) - }) - - t.Run("mssql", func(t *testing.T) { - t.Parallel() - mockMp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_mongodb.MongoClient, any](t) - mockSp := connectiontunnelmanager.NewMockConnectionProvider[neosync_benthos_sql.SqlDbtx, *sqlprovider.ConnectionClientConfig](t) - - provider := NewProvider(mockMp, mockSp) - - cc := &mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MssqlConfig{}, - } - - result := &sqlprovider.ConnectionClientConfig{} - mockSp.On("GetConnectionClientConfig", cc).Return(result, nil) - - config, err := provider.GetConnectionClientConfig(cc) - require.NoError(t, err) - require.Equal(t, result, config) - }) -} diff --git a/worker/internal/connection-tunnel-manager/providers/sqlprovider/provider.go b/worker/internal/connection-tunnel-manager/providers/sqlprovider/provider.go index bfa0e1c758..f2b6d26f1f 100644 --- a/worker/internal/connection-tunnel-manager/providers/sqlprovider/provider.go +++ b/worker/internal/connection-tunnel-manager/providers/sqlprovider/provider.go @@ -1,77 +1,49 @@ package sqlprovider import ( - "database/sql" "log/slog" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - "github.com/nucleuscloud/neosync/backend/pkg/clienttls" "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager" neosync_benthos_sql "github.com/nucleuscloud/neosync/worker/pkg/benthos/sql" ) -type Provider struct{} +type Provider struct { + connector sqlconnect.SqlConnector +} -func NewProvider() *Provider { - return &Provider{} +func NewProvider( + sqlconnector sqlconnect.SqlConnector, +) *Provider { + return &Provider{connector: sqlconnector} } -var _ connectiontunnelmanager.ConnectionProvider[neosync_benthos_sql.SqlDbtx, *ConnectionClientConfig] = &Provider{} +var _ connectiontunnelmanager.ConnectionProvider[neosync_benthos_sql.SqlDbtx] = &Provider{} -func (p *Provider) GetConnectionDetails( - cc *mgmtv1alpha1.ConnectionConfig, - connectionTimeout *uint32, - logger *slog.Logger, -) (connectiontunnelmanager.ConnectionDetails, error) { - return sqlconnect.GetConnectionDetails(cc, connectionTimeout, clienttls.UpsertCLientTlsFiles, logger) +type sqlDbtxWrapper struct { + sqlconnect.SqlDBTX + close func() error } -type ConnectionClientConfig struct { - MaxConnectionLimit *int32 +func (s *sqlDbtxWrapper) Close() error { + return s.close() } -func (p *Provider) GetConnectionClient(driver, connectionString string, opts *ConnectionClientConfig) (neosync_benthos_sql.SqlDbtx, error) { - db, err := sql.Open(driver, connectionString) +func (p *Provider) GetConnectionClient(cc *mgmtv1alpha1.ConnectionConfig) (neosync_benthos_sql.SqlDbtx, error) { + container, err := p.connector.NewDbFromConnectionConfig(cc, nil, slog.Default()) if err != nil { return nil, err } - if opts != nil && opts.MaxConnectionLimit != nil { - db.SetMaxOpenConns(int(*opts.MaxConnectionLimit)) + dbtx, err := container.Open() + if err != nil { + return nil, err } - return db, nil + return &sqlDbtxWrapper{SqlDBTX: dbtx, close: func() error { + return container.Close() + }}, nil } func (p *Provider) CloseClientConnection(client neosync_benthos_sql.SqlDbtx) error { return client.Close() } - -func (p *Provider) GetConnectionClientConfig(cc *mgmtv1alpha1.ConnectionConfig) (*ConnectionClientConfig, error) { - return &ConnectionClientConfig{ - MaxConnectionLimit: getMaxConnectionLimitFromConnection(cc), - }, nil -} - -func getMaxConnectionLimitFromConnection(cc *mgmtv1alpha1.ConnectionConfig) *int32 { - if cc == nil { - return nil - } - switch config := cc.GetConfig().(type) { - case *mgmtv1alpha1.ConnectionConfig_MysqlConfig: - if config.MysqlConfig != nil && config.MysqlConfig.ConnectionOptions != nil { - return config.MysqlConfig.ConnectionOptions.MaxConnectionLimit - } - return nil - case *mgmtv1alpha1.ConnectionConfig_PgConfig: - if config.PgConfig != nil && config.PgConfig.ConnectionOptions != nil { - return config.PgConfig.ConnectionOptions.MaxConnectionLimit - } - return nil - case *mgmtv1alpha1.ConnectionConfig_MssqlConfig: - if config.MssqlConfig != nil && config.MssqlConfig.GetConnectionOptions() != nil { - return config.MssqlConfig.GetConnectionOptions().MaxConnectionLimit - } - return nil - } - return nil -} diff --git a/worker/internal/connection-tunnel-manager/providers/sqlprovider/provider_test.go b/worker/internal/connection-tunnel-manager/providers/sqlprovider/provider_test.go deleted file mode 100644 index 22395e640f..0000000000 --- a/worker/internal/connection-tunnel-manager/providers/sqlprovider/provider_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package sqlprovider - -import ( - "testing" - - mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" - "github.com/stretchr/testify/require" -) - -func Test_getMaxConnectionLimitFromConnection(t *testing.T) { - var nilInt32 *int32 - maxConnLimit := int32(50) - - t.Run("postgres", func(t *testing.T) { - actual := getMaxConnectionLimitFromConnection(nil) - require.Empty(t, actual) - - actual = getMaxConnectionLimitFromConnection(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{}, - }) - require.Equal(t, nilInt32, actual) - - actual = getMaxConnectionLimitFromConnection(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_PgConfig{ - PgConfig: &mgmtv1alpha1.PostgresConnectionConfig{ - ConnectionOptions: &mgmtv1alpha1.SqlConnectionOptions{ - MaxConnectionLimit: &maxConnLimit, - }, - }, - }, - }) - require.Equal(t, &maxConnLimit, actual) - }) - - t.Run("mysql", func(t *testing.T) { - actual := getMaxConnectionLimitFromConnection(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{}, - }) - require.Equal(t, nilInt32, actual) - - actual = getMaxConnectionLimitFromConnection(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MysqlConfig{ - MysqlConfig: &mgmtv1alpha1.MysqlConnectionConfig{ - ConnectionOptions: &mgmtv1alpha1.SqlConnectionOptions{ - MaxConnectionLimit: &maxConnLimit, - }, - }, - }, - }) - require.Equal(t, &maxConnLimit, actual) - }) - - t.Run("mssql", func(t *testing.T) { - actual := getMaxConnectionLimitFromConnection(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MssqlConfig{}, - }) - require.Equal(t, nilInt32, actual) - - actual = getMaxConnectionLimitFromConnection(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_MssqlConfig{ - MssqlConfig: &mgmtv1alpha1.MssqlConnectionConfig{ - ConnectionOptions: &mgmtv1alpha1.SqlConnectionOptions{ - MaxConnectionLimit: &maxConnLimit, - }, - }, - }, - }) - require.Equal(t, &maxConnLimit, actual) - }) - - t.Run("awss3", func(t *testing.T) { - actual := getMaxConnectionLimitFromConnection(&mgmtv1alpha1.ConnectionConfig{ - Config: &mgmtv1alpha1.ConnectionConfig_AwsS3Config{}, - }) - require.Empty(t, actual) - }) -} diff --git a/worker/pkg/workflows/datasync/activities/sync/activity.go b/worker/pkg/workflows/datasync/activities/sync/activity.go index c2e36bb07c..624ab99ec3 100644 --- a/worker/pkg/workflows/datasync/activities/sync/activity.go +++ b/worker/pkg/workflows/datasync/activities/sync/activity.go @@ -16,10 +16,10 @@ import ( _ "github.com/warpstreamlabs/bento/public/components/pure" _ "github.com/warpstreamlabs/bento/public/components/pure/extended" _ "github.com/warpstreamlabs/bento/public/components/redis" - _ "github.com/warpstreamlabs/bento/public/components/sql" neosynclogger "github.com/nucleuscloud/neosync/backend/pkg/logger" "github.com/nucleuscloud/neosync/backend/pkg/metrics" + "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" connectiontunnelmanager "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager" "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager/providers" "github.com/nucleuscloud/neosync/worker/internal/connection-tunnel-manager/providers/mongoprovider" @@ -65,6 +65,7 @@ type SyncResponse struct { func New( connclient mgmtv1alpha1connect.ConnectionServiceClient, jobclient mgmtv1alpha1connect.JobServiceClient, + sqlconnector sqlconnect.SqlConnector, tunnelmanagermap *sync.Map, temporalclient client.Client, meter metric.Meter, @@ -74,6 +75,7 @@ func New( return &Activity{ connclient: connclient, jobclient: jobclient, + sqlconnector: sqlconnector, tunnelmanagermap: tunnelmanagermap, temporalclient: temporalclient, meter: meter, @@ -83,6 +85,7 @@ func New( } type Activity struct { + sqlconnector sqlconnect.SqlConnector connclient mgmtv1alpha1connect.ConnectionServiceClient jobclient mgmtv1alpha1connect.JobServiceClient tunnelmanagermap *sync.Map @@ -95,9 +98,9 @@ type Activity struct { func (a *Activity) getTunnelManagerByRunId(wfId, runId string) (connectiontunnelmanager.Interface[any], error) { connectionProvider := providers.NewProvider( mongoprovider.NewProvider(), - sqlprovider.NewProvider(), + sqlprovider.NewProvider(a.sqlconnector), ) - val, loaded := a.tunnelmanagermap.LoadOrStore(runId, connectiontunnelmanager.NewConnectionTunnelManager[any, any](connectionProvider)) + val, loaded := a.tunnelmanagermap.LoadOrStore(runId, connectiontunnelmanager.NewConnectionTunnelManager[any](connectionProvider)) manager, ok := val.(connectiontunnelmanager.Interface[any]) if !ok { return nil, fmt.Errorf("unable to retrieve connection tunnel manager from tunnel manager map. Expected *ConnectionTunnelManager, received: %T", manager) @@ -249,14 +252,8 @@ func (a *Activity) Sync(ctx context.Context, req *SyncRequest, metadata *SyncMet bdns := bdns errgrp.Go(func() error { connection := connections[idx] - // benthos raws will need to have a map of connetions due to there possibly being more than one connection per benthos run associated to the configs - // so the raws need to have connections that will be good for every connection string it will encounter in a single run - localConnStr, err := tunnelmanager.GetConnectionString(session, connection, slogger) - if err != nil { - return err - } - envKeyDsnSyncMap.Store(bdns.EnvVarKey, localConnStr) - dsnToConnectionIdMap.Store(localConnStr, connection.Id) + envKeyDsnSyncMap.Store(bdns.EnvVarKey, connection.Id) + dsnToConnectionIdMap.Store(connection.Id, connection.Id) return nil }) } diff --git a/worker/pkg/workflows/datasync/activities/sync/activity_test.go b/worker/pkg/workflows/datasync/activities/sync/activity_test.go index 6657e5d88d..716c80a2ed 100644 --- a/worker/pkg/workflows/datasync/activities/sync/activity_test.go +++ b/worker/pkg/workflows/datasync/activities/sync/activity_test.go @@ -16,6 +16,7 @@ import ( "github.com/google/uuid" mgmtv1alpha1 "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1" "github.com/nucleuscloud/neosync/backend/gen/go/protos/mgmt/v1alpha1/mgmtv1alpha1connect" + "github.com/nucleuscloud/neosync/backend/pkg/sqlconnect" "github.com/nucleuscloud/neosync/worker/pkg/workflows/datasync/activities/shared" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" @@ -59,7 +60,7 @@ output: jobclient := mgmtv1alpha1connect.NewJobServiceClient(srv.Client(), srv.URL) - activity := New(nil, jobclient, &sync.Map{}, nil, nil, benthosStreamManager, true) + activity := New(nil, jobclient, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, benthosStreamManager, true) env.RegisterActivity(activity.Sync) @@ -79,7 +80,7 @@ func Test_Sync_Run_No_BenthosConfig(t *testing.T) { benthosStreamManager := NewBenthosStreamManager() - activity := New(nil, nil, &sync.Map{}, nil, nil, benthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, benthosStreamManager, true) env.RegisterActivity(activity.Sync) @@ -93,7 +94,7 @@ func Test_Sync_Run_Success(t *testing.T) { env := testSuite.NewTestActivityEnvironment() benthosStreamManager := NewBenthosStreamManager() - activity := New(nil, nil, &sync.Map{}, nil, nil, benthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, benthosStreamManager, true) env.RegisterActivity(activity.Sync) @@ -123,7 +124,7 @@ func Test_Sync_Run_Metrics_Success(t *testing.T) { meterProvider := metricsdk.NewMeterProvider() meter := meterProvider.Meter("test") benthosStreamManager := NewBenthosStreamManager() - activity := New(nil, nil, &sync.Map{}, nil, meter, benthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, meter, benthosStreamManager, true) env.RegisterActivity(activity.Sync) @@ -153,7 +154,7 @@ func Test_Sync_Fake_Mutation_Success(t *testing.T) { env := testSuite.NewTestActivityEnvironment() benthosStreamManager := NewBenthosStreamManager() - activity := New(nil, nil, &sync.Map{}, nil, nil, benthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, benthosStreamManager, true) env.RegisterActivity(activity.Sync) val, err := env.ExecuteActivity(activity.Sync, &SyncRequest{ @@ -185,7 +186,7 @@ func Test_Sync_Run_Success_Javascript(t *testing.T) { env := testSuite.NewTestActivityEnvironment() benthosStreamManager := NewBenthosStreamManager() - activity := New(nil, nil, &sync.Map{}, nil, nil, benthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, benthosStreamManager, true) env.RegisterActivity(activity.Sync) tmpFile, err := os.CreateTemp("", "test") @@ -241,7 +242,7 @@ func Test_Sync_Run_Success_MutataionAndJavascript(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestActivityEnvironment() benthosStreamManager := NewBenthosStreamManager() - activity := New(nil, nil, &sync.Map{}, nil, nil, benthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, benthosStreamManager, true) env.RegisterActivity(activity.Sync) tmpFile, err := os.CreateTemp("", "test") @@ -300,7 +301,7 @@ func Test_Sync_Run_Processor_Error(t *testing.T) { env := testSuite.NewTestActivityEnvironment() benthosStreamManager := NewBenthosStreamManager() - activity := New(nil, nil, &sync.Map{}, nil, nil, benthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, benthosStreamManager, true) env.RegisterActivity(activity.Sync) @@ -331,7 +332,7 @@ func Test_Sync_Run_Output_Error(t *testing.T) { mockBenthosStreamManager := NewMockBenthosStreamManagerClient(t) mockBenthosStream := NewMockBenthosStreamClient(t) - activity := New(nil, nil, &sync.Map{}, nil, nil, mockBenthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, mockBenthosStreamManager, true) env.RegisterActivity(activity.Sync) @@ -382,7 +383,7 @@ output: mockBenthosStream.On("Run", mock.Anything).After(5 * time.Second).Return(nil) mockBenthosStream.On("StopWithin", mock.Anything).Return(nil) - activity := New(nil, nil, &sync.Map{}, nil, nil, mockBenthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, mockBenthosStreamManager, true) env.RegisterActivity(activity.Sync) stopCh := make(chan struct{}) @@ -405,7 +406,7 @@ func Test_Sync_Run_ActivityWorkerStop(t *testing.T) { testSuite := &testsuite.WorkflowTestSuite{} env := testSuite.NewTestActivityEnvironment() benthosStreamManager := NewBenthosStreamManager() - activity := New(nil, nil, &sync.Map{}, nil, nil, benthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, benthosStreamManager, true) env.RegisterActivity(activity.Sync) stopCh := make(chan struct{}) @@ -456,7 +457,7 @@ output: mockBenthosStream.On("Run", mock.Anything).Return(errors.New(errmsg)) mockBenthosStream.On("StopWithin", mock.Anything).Return(nil).Maybe() - activity := New(nil, nil, &sync.Map{}, nil, nil, mockBenthosStreamManager, true) + activity := New(nil, nil, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, nil, nil, mockBenthosStreamManager, true) env.RegisterActivity(activity.Sync) _, err := env.ExecuteActivity(activity.Sync, &SyncRequest{ diff --git a/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go b/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go index eeae0ecfa1..3af8e4e0aa 100644 --- a/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go +++ b/worker/pkg/workflows/datasync/workflow/workflow_integration_test.go @@ -1539,7 +1539,7 @@ func executeWorkflow( ) var activityMeter metric.Meter disableReaper := true - syncActivity := sync_activity.New(connclient, jobclient, &sync.Map{}, temporalClientMock, activityMeter, sync_activity.NewBenthosStreamManager(), disableReaper) + syncActivity := sync_activity.New(connclient, jobclient, &sqlconnect.SqlOpenConnector{}, &sync.Map{}, temporalClientMock, activityMeter, sync_activity.NewBenthosStreamManager(), disableReaper) retrieveActivityOpts := syncactivityopts_activity.New(jobclient) runSqlInitTableStatements := runsqlinittablestmts_activity.New(jobclient, connclient, sqlmanager) accountStatusActivity := accountstatus_activity.New(userclient)