local action_mt = {} --- Checks all replacement combinations to determine which function to run. --- If no replacement can be found, then it will run the original function local run_replace_or_original = function(replacements, original_func, ...) for _, replacement_map in ipairs(replacements or {}) do for condition, replacement in pairs(replacement_map) do if condition == true or condition(...) then return replacement(...) end end end return original_func(...) end local append_action_copy = function(new, v, old) table.insert(new, v) new._func[v] = old._func[v] new._static_pre[v] = old._static_pre[v] new._pre[v] = old._pre[v] new._replacements[v] = old._replacements[v] new._static_post[v] = old._static_post[v] new._post[v] = old._post[v] end -- TODO(conni2461): Not a fan of this solution/hack. Needs to be addressed local all_mts = {} --TODO(conni2461): It gets worse. This is so bad but because we have now n mts for n actions -- We have to check all actions for relevant mts to set replace and before, after -- Its not bad for performance because its being called on startup when we attach mappings. -- Its just a bad solution local find_all_relevant_mts = function(action_name, f) for _, mt in ipairs(all_mts) do for fun, _ in pairs(mt._func) do if fun == action_name then f(mt) end end end end --- an action is metatable which allows replacement(prepend or append) of the function ---@class Action ---@field _func table: the original action function ---@field _static_pre table: will allways run before the function even if its replaced ---@field _pre table: the functions that will run before the action ---@field _replacements table: the function that replaces this action ---@field _static_post table: will allways run after the function even if its replaced ---@field _post table: the functions that will run after the action action_mt.create = function() local mt = { __call = function(t, ...) local values = {} for _, action_name in ipairs(t) do if t._static_pre[action_name] then t._static_pre[action_name](...) end if vim.tbl_isempty(t._replacements) and t._pre[action_name] then t._pre[action_name](...) end local result = { run_replace_or_original(t._replacements[action_name], t._func[action_name], ...), } for _, res in ipairs(result) do table.insert(values, res) end if t._static_post[action_name] then t._static_post[action_name](...) end if vim.tbl_isempty(t._replacements) and t._post[action_name] then t._post[action_name](...) end end return unpack(values) end, __add = function(lhs, rhs) local new_action = setmetatable({}, action_mt.create()) for _, v in ipairs(lhs) do append_action_copy(new_action, v, lhs) end for _, v in ipairs(rhs) do append_action_copy(new_action, v, rhs) end new_action.clear = function() lhs.clear() rhs.clear() end return new_action end, _func = {}, _static_pre = {}, _pre = {}, _replacements = {}, _static_post = {}, _post = {}, } mt.__index = mt mt.clear = function() mt._pre = {} mt._replacements = {} mt._post = {} end --- Replace the reference to the function with a new one temporarily function mt:replace(v) assert(#self == 1, "Cannot replace an already combined action") return self:replace_map { [true] = v } end function mt:replace_if(condition, replacement) assert(#self == 1, "Cannot replace an already combined action") return self:replace_map { [condition] = replacement } end --- Replace table with -- Example: -- -- actions.select:replace_map { -- [function() return filetype == 'lua' end] = actions.file_split, -- [function() return filetype == 'other' end] = actions.file_split_edit, -- } function mt:replace_map(tbl) assert(#self == 1, "Cannot replace an already combined action") local action_name = self[1] find_all_relevant_mts(action_name, function(another) if not another._replacements[action_name] then another._replacements[action_name] = {} end table.insert(another._replacements[action_name], 1, tbl) end) return self end function mt:enhance(opts) assert(#self == 1, "Cannot enhance already combined actions") local action_name = self[1] find_all_relevant_mts(action_name, function(another) if opts.pre then another._pre[action_name] = opts.pre end if opts.post then another._post[action_name] = opts.post end end) return self end table.insert(all_mts, mt) return mt end action_mt.transform = function(k, mt, _, v) local res = setmetatable({ k }, mt) if type(v) == "table" then res._static_pre[k] = v.pre res._static_post[k] = v.post res._func[k] = v.action else res._func[k] = v end return res end action_mt.transform_mod = function(mod) -- Pass the metatable of the module if applicable. -- This allows for custom errors, lookups, etc. local redirect = setmetatable({}, getmetatable(mod) or {}) for k, v in pairs(mod) do local mt = action_mt.create() redirect[k] = action_mt.transform(k, mt, _, v) end redirect._clear = function() for k, v in pairs(redirect) do if k ~= "_clear" then pcall(v.clear) end end end return redirect end action_mt.clear_all = function() for _, v in ipairs(all_mts) do pcall(v.clear) end end return action_mt