1
0
mirror of https://github.com/SpaceVim/SpaceVim.git synced 2025-02-03 09:50:04 +08:00
SpaceVim/bundle/plenary.nvim/lua/luassert/assert.lua
2022-05-16 22:20:10 +08:00

183 lines
4.9 KiB
Lua
Vendored

local s = require 'say'
local astate = require 'luassert.state'
local util = require 'luassert.util'
local unpack = require 'luassert.compatibility'.unpack
-- Forward declaration: geterror (defined before obj is assigned) refers to
-- this as an upvalue; the file returns this table at the end.
local obj -- the returned module table
-- Marker metatable: obj:level() attaches it to the tables it returns and
-- obj:get_level() recognises level values by checking for it by identity.
local level_mt = {}
-- Registry of namespaces (e.g. "assertion", "modifier"), each mapping a
-- lower-cased name to the entry stored by obj:register().
local namespace = require 'luassert.namespaces'
-- Builds the error text raised when an assertion fails.
--   assertion_message: `say` key for the assertion's failure text (may be nil)
--   failure_message:   optional user-supplied message; converted with
--                      tostring() when util.hastostring() accepts it,
--                      otherwise passed through astate.format_argument()
--   args:              the assertion's argument list; note obj:format()
--                      rewrites its entries IN PLACE
-- Returns "<user message>\n<assertion text>" when both exist, otherwise
-- whichever of the two is non-nil (or nil when neither is).
local function geterror(assertion_message, failure_message, args)
  local user_text = failure_message
  if util.hastostring(user_text) then
    user_text = tostring(user_text)
  elseif user_text ~= nil then
    user_text = astate.format_argument(user_text)
  end

  local formatted = s(assertion_message, obj:format(args))
  if not (formatted and user_text) then
    return formatted or user_text
  end
  return user_text .. "\n" .. formatted
end
-- Metatable for assertion "state" objects (see obj.state).
--
-- __index: a key such as "is_not_true" is lower-cased, split on "_", and each
-- piece is appended to self.tokens; the state itself is returned, so chained
-- lookups keep collecting tokens.
--
-- __call: resolves the collected tokens via util.extract_keys.
--   * If some key names a registered assertion (the last such key wins),
--     every key naming a modifier has its callback run with just the state,
--     then the assertion callback runs. A result whose truthiness disagrees
--     with self.mod raises an error built by geterror. Otherwise the
--     callback's second return value (a list) is unpacked and returned, or the
--     call's own arguments are returned when it gave none.
--   * If no key names an assertion, self.tokens is reset and each modifier's
--     callback is run with the state, the call's arguments and an error level.
--     The state is returned for further chaining.
local __state_meta = {
  __call = function(self, ...)
    local keys = util.extract_keys("assertion", self.tokens)

    local assertion
    for _, key in ipairs(keys) do
      local found = namespace.assertion[key]
      if found then
        assertion = found
      end
    end

    -- NOTE(review): util.errorlevel() is deliberately called directly from
    -- this function (never via a helper). It presumably inspects the call
    -- stack, so its result depends on the frame it is called from.

    if not assertion then
      -- Modifier-only chain: hand each modifier the call's arguments.
      local args = { n = select('#', ...), ... }
      self.tokens = {}
      for _, key in ipairs(keys) do
        local modifier = namespace.modifier[key]
        if modifier then
          modifier.callback(self, args, util.errorlevel())
        end
      end
      return self
    end

    -- Let each modifier update the state before the assertion runs.
    for _, key in ipairs(keys) do
      local modifier = namespace.modifier[key]
      if modifier then
        modifier.callback(self)
      end
    end

    -- `n` records the true argument count so trailing nils survive.
    local args = { n = select('#', ...), ... }
    local result, retargs = assertion.callback(self, args, util.errorlevel())

    -- A truthy self.mod selects the positive form and a falsy one the negated
    -- form, so failure means the result's truthiness disagrees with it.
    if (not result) == self.mod then
      local message
      if self.mod then
        message = assertion.positive_message
      else
        message = assertion.negative_message
      end
      local err = geterror(message, rawget(self, "failure_message"), args)
      error(err or "assertion failed!", util.errorlevel())
    end

    if retargs then
      return unpack(retargs)
    end
    return ...
  end,

  __index = function(self, key)
    local tokens = self.tokens
    for word in key:lower():gmatch('[^_]+') do
      tokens[#tokens + 1] = word
    end
    return self
  end,
}
-- Returns namespace `nspace` from the registry, creating an empty one first
-- if it does not exist yet.
local function ensure_namespace(nspace)
  local bucket = namespace[nspace]
  if not bucket then
    bucket = {}
    namespace[nspace] = bucket
  end
  return bucket
end

-- The module table: the luassert management API. Assertions themselves are
-- reached through the metatable attached to this table at the end of the file.
obj = {}

-- Returns a new assertion state in the positive form (mod = true) with an
-- empty token list, using __state_meta as its metatable.
function obj.state()
  local fresh = { mod = true, tokens = {} }
  return setmetatable(fresh, __state_meta)
end

-- Registers `callback` as `name` (stored lower-cased) in namespace `nspace`,
-- creating that namespace if needed. `positive_message` and
-- `negative_message` are the message keys handed to geterror when the
-- assertion fails in its positive or negated form.
function obj:register(nspace, name, callback, positive_message, negative_message)
  local key = name:lower()
  ensure_namespace(nspace)[key] = {
    callback = callback,
    name = key,
    positive_message = positive_message,
    negative_message = negative_message,
  }
end

-- Removes `name` (lower-cased) from namespace `nspace`. An empty namespace
-- is created as a side effect when `nspace` did not exist yet.
function obj:unregister(nspace, name)
  local key = name:lower()
  ensure_namespace(nspace)[key] = nil
end

-- Registers a formatter with the shared state. A formatter takes a single
-- argument and returns it as a string, or nil if it cannot format it.
function obj:add_formatter(callback)
  astate.add_formatter(callback)
end

-- Unregisters formatter `fmtr` from the shared state.
function obj:remove_formatter(fmtr)
  astate.remove_formatter(fmtr)
end

-- Formats an assertion's argument list IN PLACE and returns the same table.
-- args.n (when set) is the argument count, so trailing nils are still
-- visited; otherwise #args is used. Indices flagged in args.nofmt are left
-- untouched. args.fmtargs[i] is passed as the third argument to
-- astate.format_argument for argument i. When that returns nil (no
-- formatter found), tostring() is used instead.
function obj:format(args)
  local skip = args.nofmt or {}
  local extra = args.fmtargs or {}
  local count = args.n or #args
  for i = 1, count do
    if not skip[i] then
      local value = args[i]
      local text = astate.format_argument(value, nil, extra[i])
      if text == nil then
        text = tostring(value)
      end
      args[i] = text
    end
  end
  return args
end

-- The following are thin pass-throughs to luassert.state.
function obj:set_parameter(name, value)
  astate.set_parameter(name, value)
end

function obj:get_parameter(name)
  return astate.get_parameter(name)
end

function obj:add_spy(spy)
  astate.add_spy(spy)
end

function obj:snapshot()
  return astate.snapshot()
end

-- Wraps an error level so it can be passed as the `level` argument of
-- assert(); get_level recognises the wrapper via level_mt.
function obj:level(level)
  local wrapped = { level = level }
  return setmetatable(wrapped, level_mt)
end

-- Returns the level stored in a value created by obj:level(), or nil for
-- anything else.
function obj:get_level(level)
  if getmetatable(level) == level_mt then
    return level.level
  end
  return nil
end
-- Metatable of the module table itself.
--
-- __call makes `assert(v, message, level, ...)` behave like Lua's assert.
-- When `v` is truthy, all arguments are returned unchanged. Otherwise
-- `message` (default "assertion failed!") is raised. A `level` built with
-- obj:level(n) is counted like error()'s level, but relative to the code
-- that called assert (1 = that caller, the default).
--
-- __index starts a fresh assertion state for any key not stored on the
-- module table, so e.g. `assert.is_true` resolves through __state_meta.
local __meta = {
  __call = function(self, bool, message, level, ...)
    if bool then
      return bool, message, level, ...
    end
    -- The +1 skips this function's own frame. error() is called here
    -- directly (not as a tail call) so that offset stays correct.
    local depth = self:get_level(level) or 1
    error(message or "assertion failed!", depth + 1)
  end,

  __index = function(self, key)
    local own = rawget(self, key)
    if own then
      return own
    end
    return self.state()[key]
  end,
}
return setmetatable(obj, __meta)