From 7a7acfce921c9b5a3f5ad50e16aa2069e555026d Mon Sep 17 00:00:00 2001 From: Kyle McLamb Date: Sat, 21 Jul 2018 13:15:30 -0400 Subject: [PATCH] Add smarter return support We now track multireturns, which is like a "yeah dummy" kind of feature to be missing, and we do a better job of handling multiple return sites with the same signature, which seems to happen for reasons. I'm not really sure what past me had in mind when they created the "Literal" tag but I think this is the right way to handle it in the function sig code. --- lua-lsp/analyze.lua | 19 ++- lua-lsp/methods.lua | 66 +++++---- spec/definition_spec.lua | 33 +++++ spec/hover_spec.lua | 289 +++++++++++++++++++++++++++++++++++++++ spec/log_spec.lua | 32 ++++- spec/unicode_spec.lua | 4 + 6 files changed, 398 insertions(+), 45 deletions(-) create mode 100644 spec/hover_spec.lua diff --git a/lua-lsp/analyze.lua b/lua-lsp/analyze.lua index ef15ffd..d1191a9 100644 --- a/lua-lsp/analyze.lua +++ b/lua-lsp/analyze.lua @@ -310,7 +310,7 @@ local function gen_scopes(len, ast, uri) end end - local function save_return(a, expr) + local function save_return(a, return_node) -- move the return value up to the closest enclosing scope local mt repeat @@ -320,7 +320,11 @@ local function gen_scopes(len, ast, uri) setmetatable(a, mt) until mt.origin mt._return = mt._return or {} - table.insert(mt._return, clean_value(expr)) + local cleaned_exprs = {} + for _, return_expr in ipairs(return_node) do + table.insert(cleaned_exprs, clean_value(return_expr)) + end + table.insert(mt._return, cleaned_exprs) end local function visit_expr(node, a) @@ -411,17 +415,10 @@ local function gen_scopes(len, ast, uri) end end elseif node.tag == "Return" then - local exprlist = node[1] - if exprlist and exprlist.tag then - local expr = exprlist + for _, expr in ipairs(node) do visit_expr(expr, a) - save_return(a, expr) - elseif exprlist then - for _, expr in ipairs(exprlist) do - visit_expr(expr, a) - save_return(a, expr) - end end + save_return(a, node) elseif node.tag == "Local" then local namelist,exprlist = node[1], node[2] if exprlist then diff --git a/lua-lsp/methods.lua b/lua-lsp/methods.lua index c645013..bb75e24 100644 --- a/lua-lsp/methods.lua +++ b/lua-lsp/methods.lua @@ -134,6 +134,17 @@ local function merge_(a, b) for k, v in pairs(b) do a[k] = v end end +local function deduplicate_(tbl) + local used = {} + for i=#tbl, 1, -1 do + if used[tbl[i]] then + table.remove(tbl, i) + else + used[tbl[i]] = true + end + end +end + -- this is starting to get silly. local function make_items(k, val, isVariant, isInvoke) local item = { label = k } @@ -197,39 +208,36 @@ local function make_items(k, val, isVariant, isInvoke) end local ret = "" - local literals = { - String = "string", Number = "number", True = "bool", - False = "bool", Nil = "nil" - } if val.scope then local scope_mt = getmetatable(val.scope) if scope_mt._return then - local types, values, noValues = {}, {}, false - for _, r in ipairs(scope_mt._return) do - if literals[r.tag] then - table.insert(types, literals[r.tag]) - if not r[1] then - noValues = true + local sites = {} + for _, site in ipairs(scope_mt._return) do + local types, values, noValues = {}, {}, false + for _, r in ipairs(site) do + if r.tag == "Literal" then + table.insert(types, string.lower(r.tag)) + table.insert(values, string.lower(r.value)) elseif r.tag == "String" then - table.insert(values, string.format("%q", r[1])) - elseif r.tag == "Number" then - table.insert(values, tostring(r[1])) + table.insert(types, "string") + table.insert(values, string.format("%q", r.value)) + elseif r.tag == "Id" then + table.insert(types, r[1]) + noValues = true + else + -- not useful types + --table.insert(types, r.tag) + noValues = true end - elseif r.tag == "Id" then - table.insert(types, r[1]) - noValues = true + end + if noValues then + table.insert(sites, table.concat(types, ", ")) else - -- not useful types - --table.insert(types, r.tag) - noValues = true + table.insert(sites, table.concat(values, ", ")) end end - if noValues then - ret = table.concat(types, " | ") - else - ret = table.concat(values, "|") - end - --ret = "?" + deduplicate_(sites) + ret = table.concat(sites, " | ") end elseif val.returns then ret = {} @@ -387,12 +395,12 @@ local function getp(doc, t, k, isDefinition) if value.tag == "Require" then -- Resolve the return value of this module - local ref = analyze.module(value.module) + local ref = assert(analyze.module(value.module)) doc = ref if ref then -- start at file scope local mt = ref.scopes and getmetatable(ref.scopes[2]) - local ret = mt and mt._return and mt._return[1] + local ret = mt and mt._return and mt._return[1][1] if ret and ret.tag == "Id" then local _ _, value, doc = definition_of(ref, ret) @@ -412,9 +420,9 @@ local function getp(doc, t, k, isDefinition) key, v, doc = definition_of(doc, value.ref) if v.scope then local mt = v.scope and getmetatable(v.scope) - local rets = mt and mt._return or {} + local rets = mt and mt._return or {{}} --for _, ret in ipairs(rets) do - local ret = rets[1] + local ret = rets[1][1] -- overload. FIXME: this mutates the original which does -- not make sense if its a copy if ret.scope then diff --git a/spec/definition_spec.lua b/spec/definition_spec.lua index 1f7d733..eb46928 100644 --- a/spec/definition_spec.lua +++ b/spec/definition_spec.lua @@ -309,4 +309,37 @@ return a.jeff end) end) end) + + it("handles missing documents ", function() + mock_loop(function(rpc) + local text = [[ +local a = require'fake1' +return a.jeff +]] + local doc = { + uri = "file:///tmp/fake2.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + + --[[ FIXME: This currently produces an error, which is bad. + rpc.request("textDocument/definition", { + textDocument = doc, + position = {line = 1, character = 10} -- a.jeff + }, function(out) + assert.equal(doc1.uri, out.uri) + assert.same({line=1, character=2}, out.range.start) + end) + --]] + + rpc.request("textDocument/definition", { + textDocument = doc, + position = {line = 1, character = 8} -- a + }, function(out) + assert.equal(doc.uri, out.uri) + assert.same({line=0, character=6}, out.range.start) + end) + end) + end) end) diff --git a/spec/hover_spec.lua b/spec/hover_spec.lua new file mode 100644 index 0000000..6fc45dc --- /dev/null +++ b/spec/hover_spec.lua @@ -0,0 +1,289 @@ +local mock_loop = require 'spec.mock_loop' + +describe("textDocument/hover", function() + it("handles string returns", function() + mock_loop(function(rpc) + local text = [[ +local function myfun() + return "hi" +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'myfun() -> "hi"\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("handles number returns", function() + mock_loop(function(rpc) + local text = [[ +local function myfun() + return 42 +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'myfun() -> 42\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("handles true returns", function() + mock_loop(function(rpc) + local text = [[ +local function myfun() + return true +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'myfun() -> true\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("handles false returns", function() + mock_loop(function(rpc) + local text = [[ +local function myfun() + return false +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'myfun() -> false\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("handles nil returns", function() + mock_loop(function(rpc) + local text = [[ +local function myfun() + return nil +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'myfun() -> nil\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("handles named returns", function() + mock_loop(function(rpc) + local text = [[ +local function myfun() + local mystring = "hi" + return mystring +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'myfun() -> mystring\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("handles multireturns", function() + mock_loop(function(rpc) + local text = [[ +local function myfun() + return 1, 2, 3 +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'myfun() -> 1, 2, 3\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("handles multireturns with multiple sites", function() + mock_loop(function(rpc) + local text = [[ +local function myfun() + if ok then + return ok + else + return nil, "oops" + end +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'myfun() -> ok | nil, "oops"\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("deduplicates function return names", function() + mock_loop(function(rpc) + local text = [[ +local function branchy() + local myvar + if true then + return myvar + else + return myvar + end +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'branchy() -> myvar\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) + + it("deduplicates function return literals", function() + mock_loop(function(rpc) + local text = [[ +local function branchy() + local myvar + if true then + return "yeah" + else + return "yeah" + end + return "nah" +end +]] + local doc = { + uri = "file:///tmp/fake.lua" + } + rpc.notify("textDocument/didOpen", { + textDocument = {uri = doc.uri, text = text} + }) + local callme + rpc.request("textDocument/hover", { + textDocument = doc, + position = {line = 0, character = 16} + }, function(out) + assert.same({ + contents = {'branchy() -> "yeah" | "nah"\n'} + }, out) + callme = true + end) + assert.truthy(callme) + end) + end) +end) diff --git a/spec/log_spec.lua b/spec/log_spec.lua index 063c362..b3dfc41 100644 --- a/spec/log_spec.lua +++ b/spec/log_spec.lua @@ -14,11 +14,20 @@ describe("log.fmt", function() log.fmt("%_ %_", t1, t2)) end) it("handles %t", function() - assert.equal('{ "a", "b", "c" }', - log.fmt("%t", {"a", "b", "c"})) - - assert.equal('{ 1,\n = {}\n} { 2 }', - log.fmt("%t %t", setmetatable({1}, {}), {2})) + assert.equal( + '{ "a", "b", "c" }', + log.fmt("%t", {"a", "b", "c"})) + + assert.equal( + '{ 1,\n = {}\n} { 2 }', + log.fmt("%t %t", setmetatable({1}, {}), {2})) + + local totable = {totable = function() + return {13} + end} + assert.equal( + '12 { 13 }', + log.fmt("%t %t", 12, totable)) end) it("handles numeric args", function() assert.equal("12 nil", @@ -91,4 +100,17 @@ describe("log levels", function() assert.stub(log.file.write).was.called(5) end) + + it("can fatal error", function() + log.setTraceLevel("verbose") + log.file = {write = function() end} + io.write = log.file.write + stub(log.file, "write") + + assert.has_error(function() + log.fatal("e") + end) + + assert.stub(log.file.write).was.called(1) + end) end) diff --git a/spec/unicode_spec.lua b/spec/unicode_spec.lua index c96fd50..0d0c0df 100644 --- a/spec/unicode_spec.lua +++ b/spec/unicode_spec.lua @@ -61,5 +61,9 @@ describe("utf8 bytes <-> utf16 code units", function() s = "🤔🤔🤔🤔" assert.equal(9, unicode.to_bytes(s, 4)) assert.equal(4, unicode.to_codeunits(s, 9)) + + s = "ℤ is the set of integers" + assert.equal(6, unicode.to_bytes(s, 3)) + assert.equal(3, unicode.to_codeunits(s, 6)) end) end)