diff --git a/lib/pg_query/treewalker.rb b/lib/pg_query/treewalker.rb index 79848b08..19f0253b 100644 --- a/lib/pg_query/treewalker.rb +++ b/lib/pg_query/treewalker.rb @@ -20,13 +20,18 @@ def treewalker!(tree) # rubocop:disable Metrics/CyclomaticComplexity node = parent_node[parent_field.to_s] next if node.nil? location = parent_location + [parent_field] - yield(parent_node, parent_field, node, location) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField) nodes << [node, location] unless node.nil? end when Google::Protobuf::RepeatedField - nodes += parent_node.map.with_index { |e, idx| [e, parent_location + [idx]] } + parent_node.each_with_index do |node, parent_field| + next if node.nil? + location = parent_location + [parent_field] + yield(parent_node, parent_field, node, location) if node.is_a?(Google::Protobuf::MessageExts) || node.is_a?(Google::Protobuf::RepeatedField) + + nodes << [node, location] unless node.nil? + end end break if nodes.empty? diff --git a/spec/lib/treewalker_spec.rb b/spec/lib/treewalker_spec.rb new file mode 100644 index 00000000..b8ef7788 --- /dev/null +++ b/spec/lib/treewalker_spec.rb @@ -0,0 +1,36 @@ +require 'spec_helper' + +describe PgQuery, '.treewalker' do + it 'walks nodes contained in repeated fields' do + locations = [] + described_class.parse("SELECT to_timestamp($1)").walk! do |_, _, _, location| + locations << location + end + expect(locations).to match_array [ + [:stmts], + [:stmts, 0], + [:stmts, 0, :stmt], + [:stmts, 0, :stmt, :select_stmt], + [:stmts, 0, :stmt, :select_stmt, :distinct_clause], + [:stmts, 0, :stmt, :select_stmt, :target_list], + [:stmts, 0, :stmt, :select_stmt, :from_clause], + [:stmts, 0, :stmt, :select_stmt, :group_clause], + [:stmts, 0, :stmt, :select_stmt, :window_clause], + [:stmts, 0, :stmt, :select_stmt, :values_lists], + [:stmts, 0, :stmt, :select_stmt, :sort_clause], + [:stmts, 0, :stmt, :select_stmt, :locking_clause], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :indirection], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call, :funcname], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call, :args], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call, :agg_order], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call, :funcname, 0], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call, :args, 0], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call, :funcname, 0, :string], + [:stmts, 0, :stmt, :select_stmt, :target_list, 0, :res_target, :val, :func_call, :args, 0, :param_ref] + ] + end +end