diff --git a/lib/baby_squeel/active_record/where_chain.rb b/lib/baby_squeel/active_record/where_chain.rb index b2d1723..fa58384 100644 --- a/lib/baby_squeel/active_record/where_chain.rb +++ b/lib/baby_squeel/active_record/where_chain.rb @@ -6,7 +6,13 @@ module WhereChain # Constructs Arel for ActiveRecord::Base#where using the DSL. def has(&block) arel = DSL.evaluate(@scope, &block) - @scope.where!(arel) unless arel.blank? + + unless arel.blank? + arel = Operators::Grouping.coerce_boolean_attribute("where.has", arel) + + @scope.where!(arel) + end + @scope end end diff --git a/lib/baby_squeel/nodes/attribute.rb b/lib/baby_squeel/nodes/attribute.rb index db07f53..260d467 100644 --- a/lib/baby_squeel/nodes/attribute.rb +++ b/lib/baby_squeel/nodes/attribute.rb @@ -9,6 +9,14 @@ def initialize(parent, name) super(parent._table[@name]) end + def &(other) + boolean_binary_operator(:&, :and, other) + end + + def |(other) + boolean_binary_operator(:|, :or, other) + end + def in(rel) if rel.is_a? ::ActiveRecord::Relation Nodes.wrap ::Arel::Nodes::In.new(self, sanitize_relation(rel)) @@ -35,6 +43,13 @@ def _arel private + def boolean_binary_operator(operator, arel_method, other) + lhs = Operators::Grouping.coerce_boolean_attribute(operator, self) + rhs = Operators::Grouping.coerce_boolean_attribute(operator, other) + + lhs.send(arel_method, rhs) + end + # NullRelation must be treated as a special case, because # NullRelation#to_sql returns an empty string. As such, # we need to convert the NullRelation to regular relation. diff --git a/lib/baby_squeel/operators.rb b/lib/baby_squeel/operators.rb index 3ee0720..53a88d6 100644 --- a/lib/baby_squeel/operators.rb +++ b/lib/baby_squeel/operators.rb @@ -52,8 +52,24 @@ def op(operator, other) module Grouping extend ArelAliasing - arel_alias :&, :and - arel_alias :|, :or + + def self.coerce_boolean_attribute(op, node) + return node unless node.is_a?(Arel::Attributes::Attribute) + + unless node.type_caster.type == :boolean + raise ArgumentError, "non-boolean attribute #{node.name} passed to #{op}" + end + + Arel::Nodes::Equality.new(node, Arel::Nodes::True.new) + end + + def &(other) + self.and(Grouping.coerce_boolean_attribute(:&, other)) + end + + def |(other) + self.or(Grouping.coerce_boolean_attribute(:|, other)) + end end module Matching diff --git a/spec/integration/__snapshots__/where_chain_spec.yaml b/spec/integration/__snapshots__/where_chain_spec.yaml index edce3c7..3a7a406 100644 --- a/spec/integration/__snapshots__/where_chain_spec.yaml +++ b/spec/integration/__snapshots__/where_chain_spec.yaml @@ -61,3 +61,19 @@ "posts"."author_id" = 42 "#where.has wheres an association using #!= 1": SELECT "posts".* FROM "posts" WHERE "posts"."author_id" != 42 +"#where.has when using a boolean column coerces a plain boolean column reference to equality at the top-level 1": SELECT + "authors".* FROM "authors" WHERE "authors"."ugly" = 1 +"#where.has when using a boolean column coerces a plain boolean column reference to equality on the LHS of an AND 1": SELECT + "authors".* FROM "authors" WHERE "authors"."ugly" = 1 AND "authors"."id" = 1 +"#where.has when using a boolean column coerces a plain boolean column reference to equality on the RHS of an AND 1": SELECT + "authors".* FROM "authors" WHERE "authors"."id" = 1 AND "authors"."ugly" = 1 +"#where.has when using a boolean column coerces a negated plain boolean column reference to equality at the top-level 1": SELECT + "authors".* FROM "authors" +"#where.has when using a boolean column coerces a plain boolean column reference to equality on the LHS of an OR 1": SELECT + "authors".* FROM "authors" WHERE ("authors"."ugly" = 1 OR "authors"."id" = 1) +"#where.has when using a boolean column coerces a plain boolean column reference to equality on the RHS of an OR 1": SELECT + "authors".* FROM "authors" WHERE ("authors"."id" = 1 OR "authors"."ugly" = 1) +"#where.has when using a boolean column coerces a plain boolean column reference to equality on both sides of an AND 1": SELECT + "authors".* FROM "authors" WHERE "authors"."ugly" = 1 AND "authors"."ugly" = 1 +"#where.has when using a boolean column coerces a plain boolean column reference to equality on both sides of an OR 1": SELECT + "authors".* FROM "authors" WHERE ("authors"."ugly" = 1 OR "authors"."ugly" = 1) diff --git a/spec/integration/where_chain_spec.rb b/spec/integration/where_chain_spec.rb index 489fba5..c54d21f 100644 --- a/spec/integration/where_chain_spec.rb +++ b/spec/integration/where_chain_spec.rb @@ -205,6 +205,64 @@ expect(bs.to_sql).to eq(ar.to_sql) end + + context 'when using a boolean column' do + it 'coerces a plain boolean column reference to equality at the top-level' do + expect(Author.where.has { ugly }).to match_sql_snapshot + end + + it 'raises with a plain non-boolean column reference at the top-level' do + expect { Author.where.has { id } }.to raise_error(ArgumentError) + end + + it 'coerces a plain boolean column reference to equality on the LHS of an AND' do + expect(Author.where.has { ugly & (id == 1) }).to match_sql_snapshot + end + + it 'raises with a plain non-boolean column reference on the LHS of an AND' do + expect { Author.where.has { id & (id == 1)} }.to raise_error(ArgumentError) + end + + it 'coerces a plain boolean column reference to equality on the RHS of an AND' do + expect(Author.where.has { (id == 1) & ugly }).to match_sql_snapshot + end + + it 'raises with a plain non-boolean column reference on the RHS of an AND' do + expect { Author.where.has { (id == 1) & id } }.to raise_error(ArgumentError) + end + + it 'coerces a plain boolean column reference to equality on both sides of an AND' do + expect(Author.where.has { ugly & ugly }).to match_sql_snapshot + end + + it 'raises with plain column references on both sides of an AND where only one is a boolean' do + expect { Author.where.has { ugly & id } }.to raise_error(ArgumentError) + end + + it 'coerces a plain boolean column reference to equality on the LHS of an OR' do + expect(Author.where.has { ugly | (id == 1) }).to match_sql_snapshot + end + + it 'raises with a plain non-boolean column reference on the LHS of an OR' do + expect { Author.where.has { id | (id == 1)} }.to raise_error(ArgumentError) + end + + it 'coerces a plain boolean column reference to equality on the RHS of an OR' do + expect(Author.where.has { (id == 1) | ugly }).to match_sql_snapshot + end + + it 'raises with a plain non-boolean column reference on the RHS of an OR' do + expect { Author.where.has { (id == 1) | id } }.to raise_error(ArgumentError) + end + + it 'coerces a plain boolean column reference to equality on both sides of an OR' do + expect(Author.where.has { ugly | ugly }).to match_sql_snapshot + end + + it 'raises with plain column references on both sides of an OR where only one is a boolean' do + expect { Author.where.has { ugly | id } }.to raise_error(ArgumentError) + end + end end describe '#where_values_hash' do