diff --git a/session_cond.go b/session_cond.go index 948a90bc1..697751872 100644 --- a/session_cond.go +++ b/session_cond.go @@ -64,7 +64,7 @@ func (session *Session) NotIn(column string, args ...interface{}) *Session { return session } -// Conds returns session query conditions +// Conds returns session query conditions except auto bean conditions func (session *Session) Conds() builder.Cond { return session.Statement.cond } diff --git a/session_cond_test.go b/session_cond_test.go index d90fbc2f0..7b9e8a07f 100644 --- a/session_cond_test.go +++ b/session_cond_test.go @@ -260,3 +260,35 @@ func TestIn(t *testing.T) { panic(err) } } + +func TestFindAndCount(t *testing.T) { + assert.NoError(t, prepareEngine()) + + type FindAndCount struct { + Id int64 + Name string + } + + assert.NoError(t, testEngine.Sync2(new(FindAndCount))) + + _, err := testEngine.Insert([]FindAndCount{ + { + Name: "test1", + }, + { + Name: "test2", + }, + }) + assert.NoError(t, err) + + var results []FindAndCount + sess := testEngine.Where("name = ?", "test1") + conds := sess.Conds() + err = sess.Find(&results) + assert.NoError(t, err) + assert.EqualValues(t, 1, len(results)) + + total, err := testEngine.Where(conds).Count(new(FindAndCount)) + assert.NoError(t, err) + assert.EqualValues(t, 1, total) +} diff --git a/session_find.go b/session_find.go index b711991e1..2518af42d 100644 --- a/session_find.go +++ b/session_find.go @@ -119,7 +119,8 @@ func (session *Session) Find(rowsSlicePtr interface{}, condiBean ...interface{}) } } - condSQL, condArgs, err := builder.ToSQL(session.Statement.cond.And(autoCond)) + session.Statement.cond = session.Statement.cond.And(autoCond) + condSQL, condArgs, err := builder.ToSQL(session.Statement.cond) if err != nil { return err } diff --git a/statement.go b/statement.go index 6e360bb37..0f90f64d7 100644 --- a/statement.go +++ b/statement.go @@ -890,17 +890,24 @@ func (statement *Statement) buildConds(table *core.Table, bean interface{}, incl statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName) } -func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { +func (statement *Statement) mergeConds(bean interface{}) error { if !statement.noAutoCondition { var addedTableName = (len(statement.JoinStr) > 0) autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName) if err != nil { - return "", nil, err + return err } statement.cond = statement.cond.And(autoCond) } if err := statement.processIDParam(); err != nil { + return err + } + return nil +} + +func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) { + if err := statement.mergeConds(bean); err != nil { return "", nil, err } @@ -940,14 +947,12 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, columnStr = "*" } - var condSQL string - var condArgs []interface{} - var err error if isStruct { - condSQL, condArgs, err = statement.genConds(bean) - } else { - condSQL, condArgs, err = builder.ToSQL(statement.cond) + if err := statement.mergeConds(bean); err != nil { + return "", nil, err + } } + condSQL, condArgs, err := builder.ToSQL(statement.cond) if err != nil { return "", nil, err }