diff --git a/lib/pg_query/fingerprint.rb b/lib/pg_query/fingerprint.rb index a8ecf95d..53230e60 100644 --- a/lib/pg_query/fingerprint.rb +++ b/lib/pg_query/fingerprint.rb @@ -72,9 +72,8 @@ def fingerprint_node(node, hash, parent_node_name = nil, parent_field_name = nil return if ignored_node_type?(node) if node.is_a?(Node) - return if node.node.nil? - node_val = node[node.node.to_s] - unless ignored_node_type?(node_val) + node_val = node.inner + unless node_val.nil? || ignored_node_type?(node_val) unless node_val.is_a?(List) postgres_node_name = node_protobuf_field_name_to_json_name(node.class, node.node) hash.update(postgres_node_name) diff --git a/lib/pg_query/node.rb b/lib/pg_query/node.rb index 74ee8c39..54fe5068 100644 --- a/lib/pg_query/node.rb +++ b/lib/pg_query/node.rb @@ -1,22 +1,27 @@ module PgQuery # Patch the auto-generated generic node type with additional convenience functions class Node + def self.inner_class_to_name(klass) + @inner_class_to_name ||= descriptor.lookup_oneof('node').to_h { |f| [f.subtype.msgclass, f.name.to_sym] } + @inner_class_to_name[klass] + end + + def inner + self[node.to_s] + end + + def inner=(submsg) + name = self.class.inner_class_to_name(submsg.class) + public_send("#{name}=", submsg) + end + def inspect - node ? format('<PgQuery::Node: %s: %s>', node, public_send(node).inspect) : '<PgQuery::Node>' + node ? format('<PgQuery::Node: %s: %s>', node, inner.inspect) : '<PgQuery::Node>' end # Make it easier to initialize nodes from a given node child object def self.from(node_field_val) - # This needs to match libpg_query naming for the Node message field names - # (see "underscore" method in libpg_query's scripts/generate_protobuf_and_funcs.rb) - node_field_name = node_field_val.class.name.split('::').last - node_field_name.gsub!(/^([A-Z\d])([A-Z][a-z])/, '\1__\2') - node_field_name.gsub!(/([A-Z\d]+[a-z]+)([A-Z][a-z])/, '\1_\2') - node_field_name.gsub!(/([a-z\d])([A-Z])/, '\1_\2') - node_field_name.tr!('-', '_') - node_field_name.downcase! 
- - PgQuery::Node.new(node_field_name => node_field_val) + PgQuery::Node.new(inner_class_to_name(node_field_val.class) => node_field_val) end # Make it easier to initialize value nodes diff --git a/lib/pg_query/parse.rb b/lib/pg_query/parse.rb index a920847c..0b0483bf 100644 --- a/lib/pg_query/parse.rb +++ b/lib/pg_query/parse.rb @@ -145,7 +145,7 @@ def load_objects! # rubocop:disable Metrics/CyclomaticComplexity end # The following statements modify the contents of a table when :insert_stmt, :update_stmt, :delete_stmt - value = statement.public_send(statement.node) + value = statement.inner from_clause_items << { item: PgQuery::Node.new(range_var: value.relation), type: :dml } statements << value.select_stmt if statement.node == :insert_stmt && value.select_stmt diff --git a/lib/pg_query/treewalker.rb b/lib/pg_query/treewalker.rb index 19f0253b..e62b5858 100644 --- a/lib/pg_query/treewalker.rb +++ b/lib/pg_query/treewalker.rb @@ -1,8 +1,21 @@ module PgQuery class ParserResult - def walk! - treewalker!(@tree) do |parent_node, parent_field, node, location| - yield(parent_node, parent_field, node, location) + # Walks the parse tree and calls the passed block for each contained node + # + # If you pass a block with 1 argument, you will get each node. + # If you pass a block with 4 arguments, you will get each parent_node, parent_field, node and location. + # + # Location uniquely identifies a given node within the parse tree. This is a stable identifier across + # multiple parser runs, assuming the same pg_query release and no modifications to the parse tree. 
+ def walk!(&block) + if block.arity == 1 + treewalker!(@tree) do |_, _, node, _| + yield(node) + end + else + treewalker!(@tree) do |parent_node, parent_field, node, location| + yield(parent_node, parent_field, node, location) + end end end diff --git a/spec/lib/treewalker_spec.rb b/spec/lib/treewalker_spec.rb index b8ef7788..3b7870c3 100644 --- a/spec/lib/treewalker_spec.rb +++ b/spec/lib/treewalker_spec.rb @@ -33,4 +33,14 @@ [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call, :args, 0, :param_ref] ] end + + it 'allows recursively replacing nodes' do + query = PgQuery.parse("SELECT * FROM tbl WHERE col::text = ANY(((ARRAY[$39, $40])::varchar[])::text[])") + query.walk! do |node| + next unless node.is_a?(PgQuery::Node) + # Keep removing type casts until we hit a different class + node.inner = node.type_cast.arg.inner while node.node == :type_cast + end + expect(query.deparse).to eq 'SELECT * FROM tbl WHERE col = ANY(ARRAY[$39, $40])' + end end