Skip to content

Commit

Permalink
Handle nested classes and nested methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mfussenegger committed Jun 2, 2024
1 parent 3dffa58 commit aae4fa4
Show file tree
Hide file tree
Showing 2 changed files with 222 additions and 96 deletions.
242 changes: 146 additions & 96 deletions lua/dap-python.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@ end
--- Built-in are test runners for unittest, pytest and django.
--- The key is the test runner name, the value a function to generate the
--- module name to run and its arguments. See |dap-python.TestRunner|
---@type table<string,TestRunner>
---@type table<string, TestRunner>
M.test_runners = {}

local function prune_nil(items)
return vim.tbl_filter(function(x) return x end, items)
end

local is_windows = function()
return vim.fn.has("win32") == 1
Expand Down Expand Up @@ -146,29 +143,51 @@ local function get_module_path()
end
end


---@return string[]
local function flatten(...)
local argc = select("#", ...)
local result = {}
for i = 1, argc do
local arg = select(i, ...)
if type(arg) == "table" then
vim.list_extend(result, arg)
else
table.insert(result, arg)
end
end
return result
end


---@private
function M.test_runners.unittest(classname, methodname)
local path = get_module_path()
local test_path = table.concat(prune_nil({path, classname, methodname}), '.')
---@param classnames string[]|string
---@param methodname string?
function M.test_runners.unittest(classnames, methodname)
local test_path = table.concat(flatten(get_module_path(), classnames, methodname), '.')
local args = {'-v', test_path}
return 'unittest', args
end


---@private
function M.test_runners.pytest(classname, methodname)
---@param classnames string[]|string
---@param methodname string?
function M.test_runners.pytest(classnames, methodname)
local path = vim.fn.expand('%:p')
local test_path = table.concat(prune_nil({path, classname, methodname}), '::')
local test_path = table.concat(flatten({path, classnames, methodname}), '::')
-- -s "allow output to stdout of test"
local args = {'-s', test_path}
return 'pytest', args
end


---@private
function M.test_runners.django(classname, methodname)
---@param classnames string[]|string
---@param methodname string?
function M.test_runners.django(classnames, methodname)
local path = get_module_path()
local test_path = table.concat(prune_nil({path, classname, methodname}), '.')
local test_path = table.concat(flatten({path, classnames, methodname}), '.')
local args = {'test', test_path}
return 'django', args
end
Expand Down Expand Up @@ -258,50 +277,10 @@ function M.setup(adapter_python_path, opts)
end


local function get_nodes(query_text, predicate)
local end_row = api.nvim_win_get_cursor(0)[1]
local ft = api.nvim_buf_get_option(0, 'filetype')
assert(ft == 'python', 'test_method of dap-python only works for python files, not ' .. ft)
local query = (vim.treesitter.query.parse
and vim.treesitter.query.parse(ft, query_text)
or vim.treesitter.parse_query(ft, query_text)
)
assert(query, 'Could not parse treesitter query. Cannot find test')
local parser = vim.treesitter.get_parser(0)
local root = (parser:parse()[1]):root()
local nodes = {}
for _, node in query:iter_captures(root, 0, 0, end_row) do
if predicate(node) then
table.insert(nodes, node)
end
end
return nodes
end


local function get_function_nodes()
local query_text = [[
(function_definition
name: (identifier) @name) @definition.function
]]
return get_nodes(query_text, function(node)
return node:type() == 'identifier'
end)
end


local function get_class_nodes()
local query_text = [[
(class_definition
name: (identifier) @name) @definition.class
]]
return get_nodes(query_text, function(node)
return node:type() == 'identifier'
end)
end


local function get_node_text(node)
if vim.treesitter.get_node_text then
return vim.treesitter.get_node_text(node, 0)
end
local row1, col1, row2, col2 = node:range()
if row1 == row2 then
row2 = row2 + 1
Expand All @@ -314,24 +293,90 @@ local function get_node_text(node)
end


local function get_parent_classname(node)
local parent = node:parent()
while parent do
local type = parent:type()
if type == 'class_definition' then
for child in parent:iter_children() do
if child:type() == 'identifier' then
return get_node_text(child)
end
--- Reverse list inline
---@param list any[]
local function reverse(list)
local len = #list
for i = 1, math.floor(len * 0.5) do
local opposite = len - i + 1
list[i], list[opposite] = list[opposite], list[i]
end
end


---@param source string|integer
---@param subject "function"|"class"
---@param end_row integer? defaults to cursor
---@return TSNode[]
function M._get_nodes(source, subject, end_row)
end_row = end_row or api.nvim_win_get_cursor(0)[1]
local query_text = [[
(function_definition
name: (identifier) @function
)
(class_definition
name: (identifier) @class
)
]]
local lang = "python"
local query = (vim.treesitter.query.parse
and vim.treesitter.query.parse(lang, query_text)
or vim.treesitter.parse_query(lang, query_text)
)
local parser = (
type(source) == "number"
and vim.treesitter.get_parser(source, lang)
or vim.treesitter.get_string_parser(source --[[@as string]], lang)
)
local trees = parser:parse()
local root = trees[1]:root()
local nodes = {}
for id, node in query:iter_captures(root, source, 0, end_row) do
local capture = query.captures[id]
if capture == subject then
table.insert(nodes, node)
end
end
if not next(nodes) then
return nodes
end
if subject == "function" then
local result = nodes[#nodes]
local parent = result
while parent ~= nil do
if parent:type() == "function_definition" then
local ident = parent:child(1)
assert(ident:type() == "identifier")
result = ident
end
parent = parent:parent()
end
parent = parent:parent()
return { result }
elseif subject == "class" then
local last = nodes[#nodes]
local parent = last
local results = {}
while parent ~= nil do
if parent:type() == "class_definition" then
local ident = parent:child(1)
assert(ident:type() == "identifier")
table.insert(results, ident)
end
parent = parent:parent()
end
reverse(results)
return results
else
error("Expected subject 'function' or 'class', not: " .. subject)
end
end


---@param classnames string[]
---@param methodname string?
---@param opts DebugOpts
local function trigger_test(classname, methodname, opts)
local function trigger_test(classnames, methodname, opts)
local test_runner = opts.test_runner or (M.test_runner or default_runner)
if type(test_runner) == "function" then
test_runner = test_runner()
Expand All @@ -342,9 +387,11 @@ local function trigger_test(classname, methodname, opts)
return
end
assert(type(runner) == "function", "Test runner must be a function")
local module, args = runner(classname, methodname)
-- for BWC with custom runners which expect a string instead of a list of strings
local classes = #classnames == 1 and classnames[1] or classnames
local module, args = runner(classes, methodname)
local config = {
name = table.concat(prune_nil({classname, methodname}), '.'),
name = table.concat(flatten(classnames, methodname), '.'),
type = 'python',
request = 'launch',
module = module,
Expand All @@ -355,49 +402,51 @@ local function trigger_test(classname, methodname, opts)
end


local function closest_above_cursor(nodes)
local result
for _, node in pairs(nodes) do
if not result then
result = node
else
local node_row1, _, _, _ = node:range()
local result_row1, _, _, _ = result:range()
if node_row1 > result_row1 then
result = node
end
end
end
return result
end


--- Run test class above cursor
---@param opts? DebugOpts See |dap-python.DebugOpts|
function M.test_class(opts)
opts = vim.tbl_extend('keep', opts or {}, default_test_opts)
local class_node = closest_above_cursor(get_class_nodes())
if not class_node then
print('No suitable test class found')
local candidates = M._get_nodes(0, "class")
if not candidates then
print('No test class found near cursor')
return
end
local class = get_node_text(class_node)
trigger_test(class, nil, opts)
local names = vim.tbl_map(get_node_text, candidates)
trigger_test(names, nil, opts)
end


---@param node TSNode
---@result TSNode[]
local function get_parent_classes(node)
local parent = node:parent()
local result = {}
while parent ~= nil do
if parent:type() == "class_definition" then
local ident = parent:child(1)
assert(ident and ident:type() == "identifier")
table.insert(result, ident)
end
parent = parent:parent()
end
reverse(result)
return result
end


--- Run the test method above cursor
---@param opts? DebugOpts See |dap-python.DebugOpts|
function M.test_method(opts)
opts = vim.tbl_extend('keep', opts or {}, default_test_opts)
local function_node = closest_above_cursor(get_function_nodes())
if not function_node then
print('No suitable test method found')
local functions = M._get_nodes(0, "function")
if not functions then
print('No test method found near cursor')
return
end
local class = get_parent_classname(function_node)
local function_name = get_node_text(function_node)
trigger_test(class, function_name, opts)
local fn = functions[1]
local parent_classes = get_parent_classes(fn)
local classnames = vim.tbl_map(get_node_text, parent_classes)
trigger_test(classnames, get_node_text(fn), opts)
end


Expand All @@ -414,6 +463,7 @@ local function remove_indent(lines)
end
end
if offset > 1 then
assert(offset)
return vim.tbl_map(function(x) return string.sub(x, offset) end, lines)
else
return lines
Expand Down Expand Up @@ -479,7 +529,7 @@ end
---@field pythonPath string|nil Path to python interpreter. Uses interpreter from `VIRTUAL_ENV` environment variable or `adapter_python_path` by default


---@alias TestRunner fun(classname: string, methodname: string):string, string[]
---@alias TestRunner fun(classname: string|string[], methodname: string?):string, string[]

---@alias DebugpyConsole "internalConsole"|"integratedTerminal"|"externalTerminal"|nil

Expand Down
Loading

0 comments on commit aae4fa4

Please sign in to comment.