Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle nested classes and nested methods #140

Merged
merged 2 commits into from
Jun 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
neovim_version: ['nightly', 'v0.6.1', 'v0.7.0']
neovim_version: ['nightly', 'v0.9.5', 'v0.10.0']

steps:
- uses: actions/checkout@v2
Expand Down
242 changes: 146 additions & 96 deletions lua/dap-python.lua
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@ end
--- Built-in are test runners for unittest, pytest and django.
--- The key is the test runner name, the value a function to generate the
--- module name to run and its arguments. See |dap-python.TestRunner|
---@type table<string,TestRunner>
---@type table<string, TestRunner>
M.test_runners = {}

local function prune_nil(items)
return vim.tbl_filter(function(x) return x end, items)
end

local is_windows = function()
return vim.fn.has("win32") == 1
Expand Down Expand Up @@ -146,29 +143,51 @@ local function get_module_path()
end
end


---@return string[]
local function flatten(...)
local argc = select("#", ...)
local result = {}
for i = 1, argc do
local arg = select(i, ...)
if type(arg) == "table" then
vim.list_extend(result, arg)
else
table.insert(result, arg)
end
end
return result
end


---@private
function M.test_runners.unittest(classname, methodname)
local path = get_module_path()
local test_path = table.concat(prune_nil({path, classname, methodname}), '.')
---@param classnames string[]|string
---@param methodname string?
function M.test_runners.unittest(classnames, methodname)
local test_path = table.concat(flatten(get_module_path(), classnames, methodname), '.')
local args = {'-v', test_path}
return 'unittest', args
end


---@private
function M.test_runners.pytest(classname, methodname)
---@param classnames string[]|string
---@param methodname string?
function M.test_runners.pytest(classnames, methodname)
local path = vim.fn.expand('%:p')
local test_path = table.concat(prune_nil({path, classname, methodname}), '::')
local test_path = table.concat(flatten({path, classnames, methodname}), '::')
-- -s "allow output to stdout of test"
local args = {'-s', test_path}
return 'pytest', args
end


---@private
function M.test_runners.django(classname, methodname)
---@param classnames string[]|string
---@param methodname string?
function M.test_runners.django(classnames, methodname)
local path = get_module_path()
local test_path = table.concat(prune_nil({path, classname, methodname}), '.')
local test_path = table.concat(flatten({path, classnames, methodname}), '.')
local args = {'test', test_path}
return 'django', args
end
Expand Down Expand Up @@ -258,50 +277,10 @@ function M.setup(adapter_python_path, opts)
end


local function get_nodes(query_text, predicate)
local end_row = api.nvim_win_get_cursor(0)[1]
local ft = api.nvim_buf_get_option(0, 'filetype')
assert(ft == 'python', 'test_method of dap-python only works for python files, not ' .. ft)
local query = (vim.treesitter.query.parse
and vim.treesitter.query.parse(ft, query_text)
or vim.treesitter.parse_query(ft, query_text)
)
assert(query, 'Could not parse treesitter query. Cannot find test')
local parser = vim.treesitter.get_parser(0)
local root = (parser:parse()[1]):root()
local nodes = {}
for _, node in query:iter_captures(root, 0, 0, end_row) do
if predicate(node) then
table.insert(nodes, node)
end
end
return nodes
end


local function get_function_nodes()
local query_text = [[
(function_definition
name: (identifier) @name) @definition.function
]]
return get_nodes(query_text, function(node)
return node:type() == 'identifier'
end)
end


local function get_class_nodes()
local query_text = [[
(class_definition
name: (identifier) @name) @definition.class
]]
return get_nodes(query_text, function(node)
return node:type() == 'identifier'
end)
end


local function get_node_text(node)
if vim.treesitter.get_node_text then
return vim.treesitter.get_node_text(node, 0)
end
local row1, col1, row2, col2 = node:range()
if row1 == row2 then
row2 = row2 + 1
Expand All @@ -314,24 +293,90 @@ local function get_node_text(node)
end


local function get_parent_classname(node)
local parent = node:parent()
while parent do
local type = parent:type()
if type == 'class_definition' then
for child in parent:iter_children() do
if child:type() == 'identifier' then
return get_node_text(child)
end
--- Reverse list inline
---@param list any[]
local function reverse(list)
local len = #list
for i = 1, math.floor(len * 0.5) do
local opposite = len - i + 1
list[i], list[opposite] = list[opposite], list[i]
end
end


---@param source string|integer
---@param subject "function"|"class"
---@param end_row integer? defaults to cursor
---@return TSNode[]
function M._get_nodes(source, subject, end_row)
end_row = end_row or api.nvim_win_get_cursor(0)[1]
local query_text = [[
(function_definition
name: (identifier) @function
)
(class_definition
name: (identifier) @class
)
]]
local lang = "python"
local query = (vim.treesitter.query.parse
and vim.treesitter.query.parse(lang, query_text)
or vim.treesitter.parse_query(lang, query_text)
)
local parser = (
type(source) == "number"
and vim.treesitter.get_parser(source, lang)
or vim.treesitter.get_string_parser(source --[[@as string]], lang)
)
local trees = parser:parse()
local root = trees[1]:root()
local nodes = {}
for id, node in query:iter_captures(root, source, 0, end_row) do
local capture = query.captures[id]
if capture == subject then
table.insert(nodes, node)
end
end
if not next(nodes) then
return nodes
end
if subject == "function" then
local result = nodes[#nodes]
local parent = result
while parent ~= nil do
if parent:type() == "function_definition" then
local ident = parent:child(1)
assert(ident:type() == "identifier")
result = ident
end
parent = parent:parent()
end
parent = parent:parent()
return { result }
elseif subject == "class" then
local last = nodes[#nodes]
local parent = last
local results = {}
while parent ~= nil do
if parent:type() == "class_definition" then
local ident = parent:child(1)
assert(ident:type() == "identifier")
table.insert(results, ident)
end
parent = parent:parent()
end
reverse(results)
return results
else
error("Expected subject 'function' or 'class', not: " .. subject)
end
end


---@param classnames string[]
---@param methodname string?
---@param opts DebugOpts
local function trigger_test(classname, methodname, opts)
local function trigger_test(classnames, methodname, opts)
local test_runner = opts.test_runner or (M.test_runner or default_runner)
if type(test_runner) == "function" then
test_runner = test_runner()
Expand All @@ -342,9 +387,11 @@ local function trigger_test(classname, methodname, opts)
return
end
assert(type(runner) == "function", "Test runner must be a function")
local module, args = runner(classname, methodname)
-- for BWC with custom runners which expect a string instead of a list of strings
local classes = #classnames == 1 and classnames[1] or classnames
local module, args = runner(classes, methodname)
local config = {
name = table.concat(prune_nil({classname, methodname}), '.'),
name = table.concat(flatten(classnames, methodname), '.'),
type = 'python',
request = 'launch',
module = module,
Expand All @@ -355,49 +402,51 @@ local function trigger_test(classname, methodname, opts)
end


local function closest_above_cursor(nodes)
local result
for _, node in pairs(nodes) do
if not result then
result = node
else
local node_row1, _, _, _ = node:range()
local result_row1, _, _, _ = result:range()
if node_row1 > result_row1 then
result = node
end
end
end
return result
end


--- Run test class above cursor
---@param opts? DebugOpts See |dap-python.DebugOpts|
function M.test_class(opts)
opts = vim.tbl_extend('keep', opts or {}, default_test_opts)
local class_node = closest_above_cursor(get_class_nodes())
if not class_node then
print('No suitable test class found')
local candidates = M._get_nodes(0, "class")
if not candidates then
print('No test class found near cursor')
return
end
local class = get_node_text(class_node)
trigger_test(class, nil, opts)
local names = vim.tbl_map(get_node_text, candidates)
trigger_test(names, nil, opts)
end


---@param node TSNode
---@result TSNode[]
local function get_parent_classes(node)
local parent = node:parent()
local result = {}
while parent ~= nil do
if parent:type() == "class_definition" then
local ident = parent:child(1)
assert(ident and ident:type() == "identifier")
table.insert(result, ident)
end
parent = parent:parent()
end
reverse(result)
return result
end


--- Run the test method above cursor
---@param opts? DebugOpts See |dap-python.DebugOpts|
function M.test_method(opts)
opts = vim.tbl_extend('keep', opts or {}, default_test_opts)
local function_node = closest_above_cursor(get_function_nodes())
if not function_node then
print('No suitable test method found')
local functions = M._get_nodes(0, "function")
if not functions then
print('No test method found near cursor')
return
end
local class = get_parent_classname(function_node)
local function_name = get_node_text(function_node)
trigger_test(class, function_name, opts)
local fn = functions[1]
local parent_classes = get_parent_classes(fn)
local classnames = vim.tbl_map(get_node_text, parent_classes)
trigger_test(classnames, get_node_text(fn), opts)
end


Expand All @@ -414,6 +463,7 @@ local function remove_indent(lines)
end
end
if offset > 1 then
assert(offset)
return vim.tbl_map(function(x) return string.sub(x, offset) end, lines)
else
return lines
Expand Down Expand Up @@ -479,7 +529,7 @@ end
---@field pythonPath string|nil Path to python interpreter. Uses interpreter from `VIRTUAL_ENV` environment variable or `adapter_python_path` by default


---@alias TestRunner fun(classname: string, methodname: string):string, string[]
---@alias TestRunner fun(classname: string|string[], methodname: string?):string, string[]

---@alias DebugpyConsole "internalConsole"|"integratedTerminal"|"externalTerminal"|nil

Expand Down
Loading
Loading