mirror of
https://github.com/SpaceVim/SpaceVim.git
synced 2025-02-03 13:10:05 +08:00
365 lines
9.6 KiB
Lua
365 lines
9.6 KiB
Lua
-- Functions to handle locals
|
|
-- Locals are a generalization of definition and scopes
|
|
-- its the way nvim-treesitter uses to "understand" the code
|
|
|
|
local queries = require "nvim-treesitter.query"
|
|
local ts_utils = require "nvim-treesitter.ts_utils"
|
|
local ts = vim.treesitter
|
|
local api = vim.api
|
|
|
|
local M = {}
|
|
|
|
function M.collect_locals(bufnr)
|
|
return queries.collect_group_results(bufnr, "locals")
|
|
end
|
|
|
|
-- Iterates matches from a locals query file.
|
|
-- @param bufnr the buffer
|
|
-- @param root the root node
|
|
function M.iter_locals(bufnr, root)
|
|
return queries.iter_group_results(bufnr, "locals", root)
|
|
end
|
|
|
|
---@param bufnr integer
|
|
---@return any
|
|
function M.get_locals(bufnr)
|
|
return queries.get_matches(bufnr, "locals")
|
|
end
|
|
|
|
-- Creates unique id for a node based on text and range
|
|
---@param scope TSNode: the scope node of the definition
|
|
---@param node_text string: the node text to use
|
|
---@return string: a string id
|
|
function M.get_definition_id(scope, node_text)
|
|
-- Add a valid starting character in case node text doesn't start with a valid one.
|
|
return table.concat({ "k", node_text or "", scope:range() }, "_")
|
|
end
|
|
|
|
function M.get_definitions(bufnr)
|
|
local locals = M.get_locals(bufnr)
|
|
|
|
local defs = {}
|
|
|
|
for _, loc in ipairs(locals) do
|
|
if loc.definition then
|
|
table.insert(defs, loc.definition)
|
|
end
|
|
end
|
|
|
|
return defs
|
|
end
|
|
|
|
function M.get_scopes(bufnr)
|
|
local locals = M.get_locals(bufnr)
|
|
|
|
local scopes = {}
|
|
|
|
for _, loc in ipairs(locals) do
|
|
if loc.scope and loc.scope.node then
|
|
table.insert(scopes, loc.scope.node)
|
|
end
|
|
end
|
|
|
|
return scopes
|
|
end
|
|
|
|
function M.get_references(bufnr)
|
|
local locals = M.get_locals(bufnr)
|
|
|
|
local refs = {}
|
|
|
|
for _, loc in ipairs(locals) do
|
|
if loc.reference and loc.reference.node then
|
|
table.insert(refs, loc.reference.node)
|
|
end
|
|
end
|
|
|
|
return refs
|
|
end
|
|
|
|
-- Gets a table with all the scopes containing a node
|
|
-- The order is from most specific to least (bottom up)
|
|
---@param node TSNode
|
|
---@param bufnr integer
|
|
---@return TSNode[]
|
|
function M.get_scope_tree(node, bufnr)
|
|
local scopes = {} ---@type TSNode[]
|
|
|
|
for scope in M.iter_scope_tree(node, bufnr) do
|
|
table.insert(scopes, scope)
|
|
end
|
|
|
|
return scopes
|
|
end
|
|
|
|
-- Iterates over a nodes scopes moving from the bottom up
|
|
---@param node TSNode
|
|
---@param bufnr integer
|
|
---@return fun(): TSNode|nil
|
|
function M.iter_scope_tree(node, bufnr)
|
|
local last_node = node
|
|
return function()
|
|
if not last_node then
|
|
return
|
|
end
|
|
|
|
local scope = M.containing_scope(last_node, bufnr, false) or ts_utils.get_root_for_node(node)
|
|
|
|
last_node = scope:parent()
|
|
|
|
return scope
|
|
end
|
|
end
|
|
|
|
-- Gets a table of all nodes and their 'kinds' from a locals list
|
|
---@param local_def any: the local list result
|
|
---@return table: a list of node entries
|
|
function M.get_local_nodes(local_def)
|
|
local result = {}
|
|
|
|
M.recurse_local_nodes(local_def, function(def, _node, kind)
|
|
table.insert(result, vim.tbl_extend("keep", { kind = kind }, def))
|
|
end)
|
|
|
|
return result
|
|
end
|
|
|
|
-- Recurse locals results until a node is found.
|
|
-- The accumulator function is given
|
|
-- * The table of the node
|
|
-- * The node
|
|
-- * The full definition match `@definition.var.something` -> 'var.something'
|
|
-- * The last definition match `@definition.var.something` -> 'something'
|
|
---@param local_def any The locals result
|
|
---@param accumulator function The accumulator function
|
|
---@param full_match? string The full match path to append to
|
|
---@param last_match? string The last match
|
|
function M.recurse_local_nodes(local_def, accumulator, full_match, last_match)
|
|
if type(local_def) ~= "table" then
|
|
return
|
|
end
|
|
|
|
if local_def.node then
|
|
accumulator(local_def, local_def.node, full_match, last_match)
|
|
else
|
|
for match_key, def in pairs(local_def) do
|
|
M.recurse_local_nodes(def, accumulator, full_match and (full_match .. "." .. match_key) or match_key, match_key)
|
|
end
|
|
end
|
|
end
|
|
|
|
-- Get a single dimension table to look definition nodes.
|
|
-- Keys are generated by using the range of the containing scope and the text of the definition node.
|
|
-- This makes looking up a definition for a given scope a simple key lookup.
|
|
--
|
|
-- This is memoized by buffer tick. If the function is called in succession
|
|
-- without the buffer tick changing, then the previous result will be used
|
|
-- since the syntax tree hasn't changed.
|
|
--
|
|
-- Usage lookups require finding the definition of the node, so `find_definition`
|
|
-- is called very frequently, which is why this lookup must be fast as possible.
|
|
--
|
|
---@param bufnr integer: the buffer
|
|
---@return table result: a table for looking up definitions
|
|
M.get_definitions_lookup_table = ts_utils.memoize_by_buf_tick(function(bufnr)
|
|
local definitions = M.get_definitions(bufnr)
|
|
local result = {}
|
|
|
|
for _, definition in ipairs(definitions) do
|
|
for _, node_entry in ipairs(M.get_local_nodes(definition)) do
|
|
local scopes = M.get_definition_scopes(node_entry.node, bufnr, node_entry.scope)
|
|
-- Always use the highest valid scope
|
|
local scope = scopes[#scopes]
|
|
local node_text = ts.get_node_text(node_entry.node, bufnr)
|
|
local id = M.get_definition_id(scope, node_text)
|
|
|
|
result[id] = node_entry
|
|
end
|
|
end
|
|
|
|
return result
|
|
end)
|
|
|
|
-- Gets all the scopes of a definition based on the scope type
|
|
-- Scope types can be
|
|
--
|
|
-- "parent": Uses the parent of the containing scope, basically, skipping a scope
|
|
-- "global": Uses the top most scope
|
|
-- "local": Uses the containing scope of the definition. This is the default
|
|
--
|
|
---@param node TSNode: the definition node
|
|
---@param bufnr integer: the buffer
|
|
---@param scope_type string: the scope type
|
|
function M.get_definition_scopes(node, bufnr, scope_type)
|
|
local scopes = {}
|
|
local scope_count = 1 ---@type integer|nil
|
|
|
|
-- Definition is valid for the containing scope
|
|
-- and the containing scope of that scope
|
|
if scope_type == "parent" then
|
|
scope_count = 2
|
|
-- Definition is valid in all parent scopes
|
|
elseif scope_type == "global" then
|
|
scope_count = nil
|
|
end
|
|
|
|
local i = 0
|
|
for scope in M.iter_scope_tree(node, bufnr) do
|
|
table.insert(scopes, scope)
|
|
i = i + 1
|
|
|
|
if scope_count and i >= scope_count then
|
|
break
|
|
end
|
|
end
|
|
|
|
return scopes
|
|
end
|
|
|
|
---@param node TSNode
|
|
---@param bufnr integer
|
|
---@return TSNode node
|
|
---@return TSNode scope
|
|
---@return string|nil kind
|
|
function M.find_definition(node, bufnr)
|
|
local def_lookup = M.get_definitions_lookup_table(bufnr)
|
|
local node_text = ts.get_node_text(node, bufnr)
|
|
|
|
for scope in M.iter_scope_tree(node, bufnr) do
|
|
local id = M.get_definition_id(scope, node_text)
|
|
|
|
if def_lookup[id] then
|
|
local entry = def_lookup[id]
|
|
|
|
return entry.node, scope, entry.kind
|
|
end
|
|
end
|
|
|
|
return node, ts_utils.get_root_for_node(node), nil
|
|
end
|
|
|
|
-- Finds usages of a node in a given scope.
|
|
---@param node TSNode the node to find usages for
|
|
---@param scope_node TSNode the node to look within
|
|
---@return TSNode[]: a list of nodes
|
|
function M.find_usages(node, scope_node, bufnr)
|
|
bufnr = bufnr or api.nvim_get_current_buf()
|
|
local node_text = ts.get_node_text(node, bufnr)
|
|
|
|
if not node_text or #node_text < 1 then
|
|
return {}
|
|
end
|
|
|
|
local scope_node = scope_node or ts_utils.get_root_for_node(node)
|
|
local usages = {}
|
|
|
|
for match in M.iter_locals(bufnr, scope_node) do
|
|
if match.reference and match.reference.node and ts.get_node_text(match.reference.node, bufnr) == node_text then
|
|
local def_node, _, kind = M.find_definition(match.reference.node, bufnr)
|
|
|
|
if kind == nil or def_node == node then
|
|
table.insert(usages, match.reference.node)
|
|
end
|
|
end
|
|
end
|
|
|
|
return usages
|
|
end
|
|
|
|
---@param node TSNode
|
|
---@param bufnr? integer
|
|
---@param allow_scope? boolean
|
|
---@return TSNode|nil
|
|
function M.containing_scope(node, bufnr, allow_scope)
|
|
local bufnr = bufnr or api.nvim_get_current_buf()
|
|
local allow_scope = allow_scope == nil or allow_scope == true
|
|
|
|
local scopes = M.get_scopes(bufnr)
|
|
if not node or not scopes then
|
|
return
|
|
end
|
|
|
|
local iter_node = node
|
|
|
|
while iter_node ~= nil and not vim.tbl_contains(scopes, iter_node) do
|
|
iter_node = iter_node:parent()
|
|
end
|
|
|
|
return iter_node or (allow_scope and node or nil)
|
|
end
|
|
|
|
function M.nested_scope(node, cursor_pos)
|
|
local bufnr = api.nvim_get_current_buf()
|
|
|
|
local scopes = M.get_scopes(bufnr)
|
|
if not node or not scopes then
|
|
return
|
|
end
|
|
|
|
local row = cursor_pos.row ---@type integer
|
|
local col = cursor_pos.col ---@type integer
|
|
local scope = M.containing_scope(node)
|
|
|
|
for _, child in ipairs(ts_utils.get_named_children(scope)) do
|
|
local row_, col_ = child:start()
|
|
if vim.tbl_contains(scopes, child) and ((row_ + 1 == row and col_ > col) or row_ + 1 > row) then
|
|
return child
|
|
end
|
|
end
|
|
end
|
|
|
|
function M.next_scope(node)
|
|
local bufnr = api.nvim_get_current_buf()
|
|
|
|
local scopes = M.get_scopes(bufnr)
|
|
if not node or not scopes then
|
|
return
|
|
end
|
|
|
|
local scope = M.containing_scope(node)
|
|
|
|
local parent = scope:parent()
|
|
if not parent then
|
|
return
|
|
end
|
|
|
|
local is_prev = true
|
|
for _, child in ipairs(ts_utils.get_named_children(parent)) do
|
|
if child == scope then
|
|
is_prev = false
|
|
elseif not is_prev and vim.tbl_contains(scopes, child) then
|
|
return child
|
|
end
|
|
end
|
|
end
|
|
|
|
---@param node TSNode
|
|
---@return TSNode|nil
|
|
function M.previous_scope(node)
|
|
local bufnr = api.nvim_get_current_buf()
|
|
|
|
local scopes = M.get_scopes(bufnr)
|
|
if not node or not scopes then
|
|
return
|
|
end
|
|
|
|
local scope = M.containing_scope(node)
|
|
|
|
local parent = scope:parent()
|
|
if not parent then
|
|
return
|
|
end
|
|
|
|
local is_prev = true
|
|
local children = ts_utils.get_named_children(parent)
|
|
for i = #children, 1, -1 do
|
|
if children[i] == scope then
|
|
is_prev = false
|
|
elseif not is_prev and vim.tbl_contains(scopes, children[i]) then
|
|
return children[i]
|
|
end
|
|
end
|
|
end
|
|
|
|
return M
|