diff --git a/lib/pg_query/treewalker.rb b/lib/pg_query/treewalker.rb index 19f0253b..e62b5858 100644 --- a/lib/pg_query/treewalker.rb +++ b/lib/pg_query/treewalker.rb @@ -1,8 +1,21 @@ module PgQuery class ParserResult - def walk! - treewalker!(@tree) do |parent_node, parent_field, node, location| - yield(parent_node, parent_field, node, location) + # Walks the parse tree and calls the passed block for each contained node + # + # If you pass a block with 1 argument, you will get each node. + # If you pass a block with 4 arguments, you will get each parent_node, parent_field, node and location. + # + # Location uniquely identifies a given node within the parse tree. This is a stable identifier across + # multiple parser runs, assuming the same pg_query release and no modifications to the parse tree. + def walk!(&block) + if block.arity == 1 + treewalker!(@tree) do |_, _, node, _| + yield(node) + end + else + treewalker!(@tree) do |parent_node, parent_field, node, location| + yield(parent_node, parent_field, node, location) + end end end diff --git a/spec/lib/treewalker_spec.rb b/spec/lib/treewalker_spec.rb index 4a174e8b..3b7870c3 100644 --- a/spec/lib/treewalker_spec.rb +++ b/spec/lib/treewalker_spec.rb @@ -36,7 +36,7 @@ it 'allows recursively replacing nodes' do query = PgQuery.parse("SELECT * FROM tbl WHERE col::text = ANY(((ARRAY[$39, $40])::varchar[])::text[])") - query.walk! do |_parent_node, _parent_field, node, _location| + query.walk! do |node| next unless node.is_a?(PgQuery::Node) # Keep removing type casts until we hit a different class node.inner = node.type_cast.arg.inner while node.node == :type_cast