diff --git a/query_builder.go b/query_builder.go index 0310a0d..c2f4ec5 100644 --- a/query_builder.go +++ b/query_builder.go @@ -29,6 +29,8 @@ type QueryBuilder interface { BuildOrderByAndLimit(string, []string, int64, int64) string // BuildUnion generates a UNION clause from the given union information. BuildUnion([]UnionInfo, Params) string + // BuildFor generates a FOR UPDATE clause from the given expression + BuildFor(option string) string } // BaseQueryBuilder provides a basic implementation of QueryBuilder. @@ -234,6 +236,16 @@ func (q *BaseQueryBuilder) BuildLimit(limit int64, offset int64) string { return sql + fmt.Sprintf("OFFSET %v", offset) } +func (q *BaseQueryBuilder) BuildFor(option string) string { + s := "" + if option != "" { + s += option + } else { + s += "UPDATE " + } + return "FOR " + s +} + func (q *BaseQueryBuilder) quoteTableNameAndAlias(table string) string { matches := selectRegex.FindStringSubmatch(table) if len(matches) == 0 { diff --git a/select.go b/select.go index 8476fff..3937a75 100644 --- a/select.go +++ b/select.go @@ -21,19 +21,20 @@ type SelectQuery struct { builder Builder ctx context.Context - selects []string - distinct bool - selectOption string - from []string - where Expression - join []JoinInfo - orderBy []string - groupBy []string - having Expression - union []UnionInfo - limit int64 - offset int64 - params Params + selects []string + distinct bool + selectOption string + from []string + where Expression + join []JoinInfo + orderBy []string + groupBy []string + having Expression + union []UnionInfo + limit int64 + offset int64 + params Params + lockingOption string } // JoinInfo contains the specification for a JOIN clause. @@ -68,14 +69,14 @@ func NewSelectQuery(builder Builder, db *DB) *SelectQuery { } // Context returns the context associated with the query. -func (q *SelectQuery) Context() context.Context { - return q.ctx +func (s *SelectQuery) Context() context.Context { + return s.ctx } // WithContext associates a context with the query. -func (q *SelectQuery) WithContext(ctx context.Context) *SelectQuery { - q.ctx = ctx - return q +func (s *SelectQuery) WithContext(ctx context.Context) *SelectQuery { + s.ctx = ctx + return s } // Select specifies the columns to be selected. @@ -245,6 +246,12 @@ func (s *SelectQuery) AndBind(params Params) *SelectQuery { return s } +// For appends the FOR UPDATE or FOR SHARING clause +func (s *SelectQuery) For(option string) *SelectQuery { + s.lockingOption = option + return s +} + // Build builds the SELECT query and returns an executable Query object. func (s *SelectQuery) Build() *Query { params := Params{} @@ -261,6 +268,7 @@ func (s *SelectQuery) Build() *Query { qb.BuildWhere(s.where, params), qb.BuildGroupBy(s.groupBy), qb.BuildHaving(s.having, params), + qb.BuildFor(s.lockingOption), } sql := "" for _, clause := range clauses {