Skip to content

Commit

Permalink
Update for prism > 0.15.1
Browse files Browse the repository at this point in the history
  • Loading branch information
tompng committed Oct 28, 2023
1 parent 81a1dc9 commit 464cec7
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 15 deletions.
48 changes: 35 additions & 13 deletions lib/katakata_irb/type_analyzer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,12 @@ def evaluate_multi_target_node(node, scope)
KatakataIrb::Types::NIL
end

def evaluate_splat_node(node, scope)
# Raw SplatNode, incomplete code like `*a.`
evaluate_multi_write_receiver node.expression, scope, nil if node.expression
KatakataIrb::Types::NIL
end

def evaluate_implicit_node(node, scope)
evaluate node.value, scope
end
Expand Down Expand Up @@ -838,15 +844,21 @@ def const_path_write(receiver, name, value, scope)
end

def assign_required_parameter(node, value, scope)
case node
when Prism::RequiredParameterNode
case node.type
when :required_parameter_node
scope[node.name.to_s] = value || KatakataIrb::Types::OBJECT
when Prism::RequiredDestructuredParameterNode
when :required_destructured_parameter_node # Revmoed in prism > 0.15.1
values = value ? sized_splat(value, :to_ary, node.parameters.size) : []
node.parameters.zip values do |n, v|
assign_required_parameter n, v, scope
end
when Prism::SplatNode
when :multi_target_node # Added to parameters in prism > 0.15.1
parameters = [*node.lefts, *node.rest, *node.rights]
values = value ? sized_splat(value, :to_ary, parameters.size) : []
parameters.zip values do |n, v|
assign_required_parameter n, v, scope
end
when :splat_node
splat_value = value ? KatakataIrb::Types.array_of(value) : KatakataIrb::Types::ARRAY
assign_required_parameter node.expression, splat_value, scope
end
Expand Down Expand Up @@ -886,7 +898,7 @@ def assign_parameters(node, scope, args, kwargs)
args = sized_splat(args.first, :to_ary, size) if size >= 2 && args.size == 1
reqs = args.shift node.requireds.size
if node.rest
# node.rest.class is Prism::RestParameterNode
# node.rest is Prism::RestParameterNode
posts = []
opts = args.shift node.optionals.size
rest = args
Expand Down Expand Up @@ -975,7 +987,13 @@ def evaluate_match_pattern(value, pattern, scope)
KatakataIrb::Types::ARRAY
when Prism::HashPatternNode
# TODO
pattern.assocs.each { evaluate_match_pattern KatakataIrb::Types::OBJECT, _1, scope }
# assocs changed to elements in prism > 0.15.1
elements = pattern.respond_to?(:assocs) ? pattern.assocs : pattern.elements
elements.each { evaluate_match_pattern KatakataIrb::Types::OBJECT, _1, scope }
if pattern.respond_to?(:rest) && pattern.rest # prism > 0.15.1
# pattern.rest was included in pattern.assocs until prism <= 0.15.1
evaluate_match_pattern KatakataIrb::Types::OBJECT, pattern.rest, scope
end
KatakataIrb::Types::HASH
when Prism::AssocNode
evaluate_match_pattern value, pattern.value, scope if pattern.value
Expand Down Expand Up @@ -1032,18 +1050,20 @@ def evaluate_write(node, value, scope, evaluated_receivers)
end

def evaluate_multi_write(node, values, scope, evaluated_receivers)
values = sized_splat values, :to_ary, node.targets.size unless values.is_a? Array
splat_index = node.targets.find_index { _1.is_a? Prism::SplatNode }
# prism <= 0.15.1 has targets, prism > 0.15.1 has lefts, rest, rights
targets = node.respond_to?(:targets) ? node.targets : [*node.lefts, *node.rest, *node.rights]
values = sized_splat values, :to_ary, targets.size unless values.is_a? Array
splat_index = targets.find_index { _1.is_a? Prism::SplatNode }
if splat_index
pre_targets = node.targets[0...splat_index]
splat_target = node.targets[splat_index]
post_targets = node.targets[splat_index + 1..]
pre_targets = targets[0...splat_index]
splat_target = targets[splat_index]
post_targets = targets[splat_index + 1..]
pre_values = values.shift pre_targets.size
post_values = values.pop post_targets.size
splat_value = KatakataIrb::Types::UnionType[*values]
zips = pre_targets.zip(pre_values) + [[splat_target, splat_value]] + post_targets.zip(post_values)
else
zips = node.targets.zip(values)
zips = targets.zip(values)
end
zips.each do |target, value|
evaluate_write target, value || KatakataIrb::Types::NIL, scope, evaluated_receivers
Expand All @@ -1053,7 +1073,9 @@ def evaluate_multi_write(node, values, scope, evaluated_receivers)
def evaluate_multi_write_receiver(node, scope, evaluated_receivers)
case node
when Prism::MultiWriteNode, Prism::MultiTargetNode
node.targets.each { evaluate_multi_write_receiver _1, scope, evaluated_receivers }
# prism <= 0.15.1 has targets, prism > 0.15.1 has lefts, rest, rights
targets = node.respond_to?(:targets) ? node.targets : [*node.lefts, *node.rest, *node.rights]
targets.each { evaluate_multi_write_receiver _1, scope, evaluated_receivers }
when Prism::CallNode
if node.receiver
receiver = evaluate(node.receiver, scope)
Expand Down
10 changes: 9 additions & 1 deletion test/test_katakata_irb.rb
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,17 @@ def test_prism_node_names
codes = files.map do |file|
File.read File.join(File.dirname(__FILE__), '../lib/katakata_irb', file)
end
ignore_class_names = ['Prism::BlockLocalVariableNode', 'Prism::IndexAndWriteNode', 'Prism::IndexOperatorWriteNode', 'Prism::IndexOrWriteNode']
ignore_class_names = [
# Not traversed
'Prism::BlockLocalVariableNode',
# Added in prism 0.15.0
'Prism::IndexAndWriteNode', 'Prism::IndexOperatorWriteNode', 'Prism::IndexOrWriteNode',
# Removed in prism > 0.15.1
'Prism::RequiredDestructuredParameterNode'
]
implemented_node_class_names = [
*codes.join.scan(/evaluate_[a-z_]+/).grep(/_node$/).map { "Prism::#{_1[9..].split('_').map(&:capitalize).join}" },
*codes.join.scan(/:[a-z_]+_node/).map { "Prism::#{_1[1..].split('_').map(&:capitalize).join}" },
*codes.join.scan(/Prism::[A-Za-z]+Node/)
].uniq.sort - ignore_class_names
all_node_class_names = Prism.constants.grep(/Node$/).map { "Prism::#{_1}" }.sort - ['Prism::Node'] - ignore_class_names
Expand Down
2 changes: 1 addition & 1 deletion test/test_type_analyze.rb
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_def
assert_call('def (a="").f; end; a.', include: String)
assert_call('def f(a=1); a.', include: Integer)
assert_call('def f(**nil); 1.', include: Integer)
assert_call('def f(a,*b); *b.', include: Array)
assert_call('def f(a,*b); b.', include: Array)
assert_call('def f(a,x:1); x.', include: Integer)
assert_call('def f(a,x:,**); 1.', include: Integer)
assert_call('def f(a,x:,**y); y.', include: Hash)
Expand Down

0 comments on commit 464cec7

Please sign in to comment.