Skip to content

Commit

Permalink
Fix bug of [*] [**] {**} (*)
Browse files Browse the repository at this point in the history
  • Loading branch information
tompng committed Nov 2, 2023
1 parent 2c24335 commit 21cd581
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
11 changes: 7 additions & 4 deletions lib/katakata_irb/type_analyzer.rb
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def evaluate_hash(node, scope)
keys << evaluate(assoc.key, scope)
values << evaluate(assoc.value, scope)
when Prism::AssocSplatNode
next unless assoc.value # def f(**); {**}

hash = evaluate assoc.value, scope
unless hash.is_a?(KatakataIrb::Types::InstanceType) && hash.klass == Hash
hash = method_call hash, :to_hash, [], nil, nil, scope
Expand Down Expand Up @@ -461,7 +463,6 @@ def evaluate_multi_write_node(node, scope)
elsif node.value
evaluate node.value, scope
else
# For syntax invalid code like `(*a).b`
KatakataIrb::Types::NIL
end
)
Expand Down Expand Up @@ -854,7 +855,7 @@ def assign_required_parameter(node, value, scope)
end
when :splat_node
splat_value = value ? KatakataIrb::Types.array_of(value) : KatakataIrb::Types::ARRAY
assign_required_parameter node.expression, splat_value, scope
assign_required_parameter node.expression, splat_value, scope if node.expression
end
end

Expand Down Expand Up @@ -1032,7 +1033,7 @@ def evaluate_write(node, value, scope, evaluated_receivers)
when Prism::CallNode
evaluated_receivers&.[](node.receiver) || evaluate(node.receiver, scope) if node.receiver
when Prism::SplatNode
evaluate_write node.expression, KatakataIrb::Types.array_of(value), scope, evaluated_receivers
evaluate_write node.expression, KatakataIrb::Types.array_of(value), scope, evaluated_receivers if node.expression
when Prism::LocalVariableTargetNode, Prism::GlobalVariableTargetNode, Prism::InstanceVariableTargetNode, Prism::ClassVariableTargetNode, Prism::ConstantTargetNode
scope[node.name.to_s] = value
when Prism::ConstantPathTargetNode
Expand Down Expand Up @@ -1083,13 +1084,15 @@ def evaluate_multi_write_receiver(node, scope, evaluated_receivers)
def evaluate_list_splat_items(list, scope)
items = list.flat_map do |node|
if node.is_a? Prism::SplatNode
next unless node.expression # def f(*); [*]

splat = evaluate node.expression, scope
array_elem, non_array = partition_to_array splat.nonnillable, :to_a
[*array_elem, *non_array]
else
evaluate node, scope
end
end.uniq
end.compact.uniq
KatakataIrb::Types::UnionType[*items]
end

Expand Down
7 changes: 7 additions & 0 deletions test/test_type_analyze.rb
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def test_massign
assert_call('a,*b=[1,2]; a.', include: Integer, exclude: Array)
assert_call('a,*b=[1,2]; b.', include: Array, exclude: Integer)
assert_call('a,*b=[1,2]; b.sample.', include: Integer)
assert_call('a,*,(*)=[1,2]; a.', include: Integer)
assert_call('*a=[1,2]; a.', include: Array, exclude: Integer)
assert_call('*a=[1,2]; a.sample.', include: Integer)
assert_call('a,*b,c=[1,2,3]; b.', include: Array, exclude: Integer)
Expand Down Expand Up @@ -356,6 +357,7 @@ def test_def
assert_call('def (a="").f; end; a.', include: String)
assert_call('def f(a=1); a.', include: Integer)
assert_call('def f(**nil); 1.', include: Integer)
assert_call('def f((*),*); 1.', include: Integer)
assert_call('def f(a,*b); b.', include: Array)
assert_call('def f(a,x:1); x.', include: Integer)
assert_call('def f(a,x:,**); 1.', include: Integer)
Expand All @@ -366,6 +368,8 @@ def test_def
assert_call('def f(a,...); 1.', include: Integer)
assert_call('def f(...); g(...); 1.', include: Integer)
assert_call('def f(*,**,&); g(*,**,&); 1.', include: Integer)
assert_call('def f(*,**,&); {**}.', include: Hash)
assert_call('def f(*,**,&); [*,**].', include: Array)
assert_call('class Array; def f; self.', include: Array)
end

Expand Down Expand Up @@ -465,6 +469,7 @@ def test_literal
assert_call('true.', include: TrueClass)
assert_call('false.', include: FalseClass)
assert_call('nil.', include: NilClass)
assert_call('().', include: NilClass)
assert_call('//.', include: Regexp)
assert_call('/#{a=1}/.', include: Regexp)
assert_call('/#{a=1}/; a.', include: Integer)
Expand Down Expand Up @@ -569,6 +574,7 @@ def test_while_until
def test_for
assert_call('for i in [1,2,3]; i.', include: Integer)
assert_call('for i,j in [1,2,3]; i.', include: Integer)
assert_call('for *,(*) in [1,2,3]; 1.', include: Integer)
assert_call('for *i in [1,2,3]; i.sample.', include: Integer)
assert_call('for (a=1).b in [1,2,3]; a.', include: Integer)
assert_call('for Array::B in [1,2,3]; Array::B.', include: Integer)
Expand Down Expand Up @@ -674,6 +680,7 @@ def test_block_args
assert_call('[1,2,3].tap{|a,*b| b.', include: Array)
assert_call('[1,2,3].tap{|a=1.0| a.', include: [Array, Float])
assert_call('[1,2,3].tap{|a,**b| b.', include: Hash)
assert_call('1.tap{|(*),*,**| 1.', include: Integer)
end

def test_array_aref
Expand Down

0 comments on commit 21cd581

Please sign in to comment.