From 3c018be920e7e325b8eb0c29e02e87027e06c58b Mon Sep 17 00:00:00 2001 From: Horacio Duran Date: Thu, 26 Jul 2018 14:43:30 -0300 Subject: [PATCH] Add various flavors of Join to the chain --- db/chain/chain.go | 74 +++++++++++++++++++++++++++++++++++++----- db/chain/chain_test.go | 17 ++++++++++ db/chain/helpers.go | 7 ++++ db/chain/segment.go | 26 ++++++++------- 4 files changed, 105 insertions(+), 19 deletions(-) diff --git a/db/chain/chain.go b/db/chain/chain.go index 04e3ea9..d57998a 100644 --- a/db/chain/chain.go +++ b/db/chain/chain.go @@ -379,6 +379,62 @@ func (ec *ExpresionChain) Join(expr string, args ...interface{}) *ExpresionChain return ec } +// LeftJoin adds a 'LEFT JOIN' to the 'ExpresionChain' and returns the same chan to facilitate +// further chaining. +// THIS DOES NOT CREATE A COPY OF THE CHAIN, IT MUTATES IN PLACE. +func (ec *ExpresionChain) LeftJoin(expr string, args ...interface{}) *ExpresionChain { + ec.append( + querySegmentAtom{ + segment: sqlLeftJoin, + expresion: expr, + arguments: args, + sqlBool: SQLNothing, + }) + return ec +} + +// RightJoin adds a 'RIGHT JOIN' to the 'ExpresionChain' and returns the same chan to facilitate +// further chaining. +// THIS DOES NOT CREATE A COPY OF THE CHAIN, IT MUTATES IN PLACE. +func (ec *ExpresionChain) RightJoin(expr string, args ...interface{}) *ExpresionChain { + ec.append( + querySegmentAtom{ + segment: sqlRightJoin, + expresion: expr, + arguments: args, + sqlBool: SQLNothing, + }) + return ec +} + +// InnerJoin adds a 'INNER JOIN' to the 'ExpresionChain' and returns the same chan to facilitate +// further chaining. +// THIS DOES NOT CREATE A COPY OF THE CHAIN, IT MUTATES IN PLACE. +func (ec *ExpresionChain) InnerJoin(expr string, args ...interface{}) *ExpresionChain { + ec.append( + querySegmentAtom{ + segment: sqlInnerJoin, + expresion: expr, + arguments: args, + sqlBool: SQLNothing, + }) + return ec +} + +// OuterJoin adds a 'OUTER JOIN' to the 'ExpresionChain' and returns the same chan to facilitate +// further chaining. +// THIS DOES NOT CREATE A COPY OF THE CHAIN, IT MUTATES IN PLACE. +func (ec *ExpresionChain) OuterJoin(expr string, args ...interface{}) *ExpresionChain { + ec.append( + querySegmentAtom{ + segment: sqlOuterJoin, + expresion: expr, + arguments: args, + sqlBool: SQLNothing, + }) + return ec +} + // OrderBy adds a 'ORDER BY' to the 'ExpresionChain' and returns the same chan to facilitate // further chaining. // THIS DOES NOT CREATE A COPY OF THE CHAIN, IT MUTATES IN PLACE. @@ -640,16 +696,18 @@ func (ec *ExpresionChain) render(raw bool) (string, []interface{}, error) { ec.mainOperation.segment == sqlUpdate { // JOIN joins := extract(ec, sqlJoin) + joins = append(joins, extract(ec, sqlLeftJoin)...) + joins = append(joins, extract(ec, sqlRightJoin)...) + joins = append(joins, extract(ec, sqlInnerJoin)...) + joins = append(joins, extract(ec, sqlOuterJoin)...) if len(joins) != 0 { - joinSubQueries := make([]string, len(joins)) - joinArguments := []interface{}{} - for i, item := range joins { - joinSubQueries[i] = item.expresion - joinArguments = append(joinArguments, item.arguments...) + for _, join := range joins { + + query += fmt.Sprintf(" %s %s", + join.segment, + join.expresion) + args = append(args, join.arguments...) } - query += fmt.Sprintf(" JOIN %s", - strings.Join(joinSubQueries, " ")) - args = append(args, joinArguments...) } } diff --git a/db/chain/chain_test.go b/db/chain/chain_test.go index 710931b..7189069 100644 --- a/db/chain/chain_test.go +++ b/db/chain/chain_test.go @@ -155,6 +155,23 @@ func TestExpresionChain_Render(t *testing.T) { wantArgs: []interface{}{"unpirulo", 1, 2, "pajarito"}, wantErr: false, }, + { + name: "basic selection with flavors of JOIN", + chain: (&ExpresionChain{}).Select("field1", "field2", "field3"). + Table("convenient_table"). + AndWhere("field1 > ?", 1). + AndWhere("field2 = ?", 2). + AndWhere("field3 > ?", "pajarito"). + Join(JoinOn("another_convenient_table", "pirulo = ?", "unpirulo")). + Join(JoinOn("yet_another_convenient_table", "pirulo = ?", "otrounpirulo")). + LeftJoin(JoinOn("one_convenient_table", "pirulo2 = ?", "dospirulo")). + RightJoin(JoinOn("three_convenient_table", "pirulo3 = ?", "trespirulo")). + InnerJoin(JoinOn("four_convenient_table", "pirulo4 = ?", "cuatropirulo")). + OuterJoin(JoinOn("five_convenient_table", "pirulo5 = ?", "cincopirulo")), + want: "SELECT field1, field2, field3 FROM convenient_table JOIN another_convenient_table ON pirulo = $1 JOIN yet_another_convenient_table ON pirulo = $2 LEFT JOIN one_convenient_table ON pirulo2 = $3 RIGHT JOIN three_convenient_table ON pirulo3 = $4 INNER JOIN four_convenient_table ON pirulo4 = $5 OUTER JOIN five_convenient_table ON pirulo5 = $6 WHERE field1 > $7 AND field2 = $8 AND field3 > $9", + wantArgs: []interface{}{"unpirulo", "otrounpirulo", "dospirulo", "trespirulo", "cuatropirulo", "cincopirulo", 1, 2, "pajarito"}, + wantErr: false, + }, { name: "basic selection with where and join and group by", chain: (&ExpresionChain{}).Select("field1", "field2", "field3"). diff --git a/db/chain/helpers.go b/db/chain/helpers.go index 2b2d6c3..9d5be5d 100644 --- a/db/chain/helpers.go +++ b/db/chain/helpers.go @@ -103,3 +103,10 @@ func Null(field string) string { func SetToCurrentTimestamp(field string) string { return fmt.Sprintf("%s = %s", field, CurrentTimestampPGFn) } + +// JOIN helpers + +// JoinOn crafts the `table ON expression` +func JoinOn(table, expr string, args ...interface{}) (string, []interface{}) { + return fmt.Sprintf("%s ON %s", table, expr), args +} diff --git a/db/chain/segment.go b/db/chain/segment.go index 92f569b..58458cd 100644 --- a/db/chain/segment.go +++ b/db/chain/segment.go @@ -39,17 +39,21 @@ const ( type sqlSegment string const ( - sqlWhere sqlSegment = "WHERE" - sqlLimit sqlSegment = "LIMIT" - sqlOffset sqlSegment = "OFFSET" - sqlJoin sqlSegment = "JOIN" - sqlSelect sqlSegment = "SELECT" - sqlDelete sqlSegment = "DELETE" - sqlInsert sqlSegment = "INSERT" - sqlUpdate sqlSegment = "UPDATE" - sqlFrom sqlSegment = "FROM" - sqlGroup sqlSegment = "GROUP BY" - sqlOrder sqlSegment = "ORDER BY" + sqlWhere sqlSegment = "WHERE" + sqlLimit sqlSegment = "LIMIT" + sqlOffset sqlSegment = "OFFSET" + sqlJoin sqlSegment = "JOIN" + sqlLeftJoin sqlSegment = "LEFT JOIN" + sqlRightJoin sqlSegment = "RIGHT JOIN" + sqlInnerJoin sqlSegment = "INNER JOIN" + sqlOuterJoin sqlSegment = "OUTER JOIN" + sqlSelect sqlSegment = "SELECT" + sqlDelete sqlSegment = "DELETE" + sqlInsert sqlSegment = "INSERT" + sqlUpdate sqlSegment = "UPDATE" + sqlFrom sqlSegment = "FROM" + sqlGroup sqlSegment = "GROUP BY" + sqlOrder sqlSegment = "ORDER BY" // SPECIAL CASES sqlInsertMulti sqlSegment = "INSERTM" )