-- This module returns nothing; it only registers assertions (and the
-- `message` modifier) with the main luassert engine.
--
-- Assertion functions take 3 parameters:
-- 1) state     the assertion state; `state.failure_message` may be set
-- 2) arguments the argument list. Its member 'n' holds the argument count,
--              so trailing nils can be detected
-- 3) level     the level of the error position relative to the called function
-- Returns: boolean, whether the assertion passed. Some assertions (matches,
-- error, error_matches) also return a second value: a table of result values
-- (e.g. pattern captures), presumably handed back to the caller by the engine.

local assert = require('luassert.assert')
local astate = require ('luassert.state')
local util = require ('luassert.util')
local s = require('say')

-- Render a value for an error message, preferring the registered formatters.
local function format(val)
  return astate.format_argument(val) or tostring(val)
end

-- Store a user-supplied failure message on the state; nil leaves it untouched.
local function set_failure_message(state, message)
  if message ~= nil then
    state.failure_message = message
  end
end

-- Remove the first trailing argument (position >= 3) that is neither a
-- boolean nor number-like from `arguments`; it is the custom failure message.
-- Returns the message (or nil) and the position at which the `init` argument
-- was supplied (4 when the message sat at position 3), for error reporting.
local function extract_failure_message(arguments, argcnt)
  local message
  local init_arg_num = 3
  for i = 3, argcnt do
    local candidate = arguments[i]
    if candidate and type(candidate) ~= "boolean" and not tonumber(candidate) then
      if i == 3 then init_arg_num = 4 end
      message = util.tremove(arguments, i)
      break
    end
  end
  return message, init_arg_num
end

-- Match `pattern` against `str` starting at `init`.
-- With `plain` truthy a plain substring search is done and the result table
-- holds the pattern itself; otherwise it holds the captures of string.match.
-- Returns: boolean (matched), table of result values ({} on no match).
local function find_pattern(str, pattern, init, plain)
  if plain then
    if str:find(pattern, init, plain) ~= nil then
      return true, { pattern }
    end
    return false, {}
  end
  local captures = { str:match(pattern, init) }
  return captures[1] ~= nil, captures
end

-- Remove a leading 'path/to/file:line: ' position prefix from string errors;
-- any other value is returned unchanged.
local function strip_error_location(err)
  if type(err) == 'string' then
    err = err:gsub('^.-:%d+: ', '', 1)
  end
  return err
end

-- unique(list [, deep] [, message]) or unique(list, message [, deep]):
-- passes when no two entries of `list` are equal (deep-equal if `deep`).
local function unique(state, arguments, level)
  local list = arguments[1]
  local deep
  local argcnt = arguments.n
  if type(arguments[2]) == "boolean" or (arguments[2] == nil and argcnt > 2) then
    -- form: unique(list, deep, message)
    deep = arguments[2]
    set_failure_message(state, arguments[3])
  else
    -- form: unique(list, message, deep)
    if type(arguments[3]) == "boolean" then
      deep = arguments[3]
    end
    set_failure_message(state, arguments[2])
  end
  for k, v in pairs(list) do
    for k2, v2 in pairs(list) do
      if k ~= k2 and ((deep and util.deepcompare(v, v2, true)) or v == v2) then
        return false
      end
    end
  end
  return true
end

-- near(expected, actual, tolerance [, message]): passes when
-- |actual - expected| <= tolerance. All three must convert to numbers.
local function near(state, arguments, level)
  level = (level or 1) + 1
  local argcnt = arguments.n
  assert(argcnt > 2, s("assertion.internal.argtolittle", { "near", 3, tostring(argcnt) }), level)
  local expected = tonumber(arguments[1])
  local actual = tonumber(arguments[2])
  local tolerance = tonumber(arguments[3])
  local numbertype = "number or object convertible to a number"
  assert(expected, s("assertion.internal.badargtype", { 1, "near", numbertype, format(arguments[1]) }), level)
  assert(actual, s("assertion.internal.badargtype", { 2, "near", numbertype, format(arguments[2]) }), level)
  assert(tolerance, s("assertion.internal.badargtype", { 3, "near", numbertype, format(arguments[3]) }), level)
  -- switch arguments for proper output message
  util.tinsert(arguments, 1, util.tremove(arguments, 2))
  arguments[3] = tolerance
  arguments.nofmt = arguments.nofmt or {}
  arguments.nofmt[3] = true
  set_failure_message(state, arguments[4])
  return (actual >= expected - tolerance and actual <= expected + tolerance)
end

-- matches(pattern, actual [, init [, plain]] [, message]): passes when the
-- Lua pattern matches `actual` (a string, number or tostring-able object).
-- Second return value: the captures (or { pattern } for a plain search).
local function matches(state, arguments, level)
  level = (level or 1) + 1
  local argcnt = arguments.n
  assert(argcnt > 1, s("assertion.internal.argtolittle", { "matches", 2, tostring(argcnt) }), level)
  local pattern = arguments[1]
  local actual = nil
  if util.hastostring(arguments[2]) or type(arguments[2]) == "number" then
    actual = tostring(arguments[2])
  end
  local err_message, init_arg_num = extract_failure_message(arguments, argcnt)
  local init = arguments[3]
  local plain = arguments[4]
  local stringtype = "string or object convertible to a string"
  assert(type(pattern) == "string", s("assertion.internal.badargtype", { 1, "matches", "string", type(arguments[1]) }), level)
  assert(actual, s("assertion.internal.badargtype", { 2, "matches", stringtype, format(arguments[2]) }), level)
  assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "matches", "number", type(arguments[3]) }), level)
  -- switch arguments for proper output message
  util.tinsert(arguments, 1, util.tremove(arguments, 2))
  set_failure_message(state, err_message)
  return find_pattern(actual, pattern, init, plain)
end

-- equals(expected, actual [, message]): passes when expected == actual.
local function equals(state, arguments, level)
  level = (level or 1) + 1
  local argcnt = arguments.n
  assert(argcnt > 1, s("assertion.internal.argtolittle", { "equals", 2, tostring(argcnt) }), level)
  local result = arguments[1] == arguments[2]
  -- switch arguments for proper output message
  util.tinsert(arguments, 1, util.tremove(arguments, 2))
  set_failure_message(state, arguments[3])
  return result
end

-- same(expected, actual [, message]): deep comparison when both are tables,
-- plain == otherwise. For tables, the crumbs from util.deepcompare are
-- attached as format arguments so the output can point at the difference.
local function same(state, arguments, level)
  level = (level or 1) + 1
  local argcnt = arguments.n
  assert(argcnt > 1, s("assertion.internal.argtolittle", { "same", 2, tostring(argcnt) }), level)
  local deep = type(arguments[1]) == 'table' and type(arguments[2]) == 'table'
  local result, crumbs
  if deep then
    result, crumbs = util.deepcompare(arguments[1], arguments[2], true)
  else
    result = arguments[1] == arguments[2]
  end
  -- switch arguments for proper output message
  util.tinsert(arguments, 1, util.tremove(arguments, 2))
  if deep then
    arguments.fmtargs = arguments.fmtargs or {}
    arguments.fmtargs[1] = { crumbs = crumbs }
    arguments.fmtargs[2] = { crumbs = crumbs }
  end
  set_failure_message(state, arguments[3])
  return result
end

-- truthy(value [, message]): passes unless value is nil or false.
local function truthy(state, arguments, level)
  set_failure_message(state, arguments[2])
  return arguments[1] ~= false and arguments[1] ~= nil
end

-- falsy(value [, message]): passes when value is nil or false.
local function falsy(state, arguments, level)
  return not truthy(state, arguments, level)
end

-- error(func [, expected [, message]]): calls func under pcall. Without an
-- expected value it passes when func raises anything. A string expectation is
-- compared against the (location-stripped) error text, a number against the
-- numeric value of a string error, anything else via `same`.
-- Second return value: a copy of the arguments with [1] set to the error.
local function has_error(state, arguments, level)
  level = (level or 1) + 1
  local retargs = util.shallowcopy(arguments)
  local func = arguments[1]
  local err_expected = arguments[2]
  local failure_message = arguments[3]
  assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error", "function or callable object", type(func) }), level)
  local ok, err_actual = pcall(func)
  err_actual = strip_error_location(err_actual)
  retargs[1] = err_actual
  -- rewrite the argument list into (actual, expected) for the output message
  arguments.nofmt = {}
  arguments.n = 2
  arguments[1] = (ok and '(no error)' or err_actual)
  arguments[2] = (err_expected == nil and '(error)' or err_expected)
  arguments.nofmt[1] = ok
  arguments.nofmt[2] = (err_expected == nil)
  set_failure_message(state, failure_message)

  if ok or err_expected == nil then
    return not ok, retargs
  end
  if type(err_expected) == 'string' then
    -- err_actual must be (convertible to) a string
    if util.hastostring(err_actual) then
      err_actual = tostring(err_actual)
      retargs[1] = err_actual
    end
    if type(err_actual) == 'string' then
      return err_expected == err_actual, retargs
    end
  elseif type(err_expected) == 'number' then
    if type(err_actual) == 'string' then
      return tostring(err_expected) == tostring(tonumber(err_actual)), retargs
    end
  end
  return same(state, {err_expected, err_actual, ["n"] = 2}), retargs
end

-- error_matches(func, pattern [, init [, plain]] [, message]): calls func
-- under pcall and matches the (location-stripped) error against `pattern`.
-- A nil pattern only matches an error whose value is nil.
-- Second return value: the captures on success, otherwise a copy of the
-- arguments with [1] set to the error.
local function error_matches(state, arguments, level)
  level = (level or 1) + 1
  local retargs = util.shallowcopy(arguments)
  local argcnt = arguments.n
  local func = arguments[1]
  local pattern = arguments[2]
  assert(argcnt > 1, s("assertion.internal.argtolittle", { "error_matches", 2, tostring(argcnt) }), level)
  assert(util.callable(func), s("assertion.internal.badargtype", { 1, "error_matches", "function or callable object", type(func) }), level)
  assert(pattern == nil or type(pattern) == "string", s("assertion.internal.badargtype", { 2, "error_matches", "string", type(pattern) }), level)
  local failure_message, init_arg_num = extract_failure_message(arguments, argcnt)
  local init = arguments[3]
  local plain = arguments[4]
  assert(init == nil or tonumber(init), s("assertion.internal.badargtype", { init_arg_num, "error_matches", "number", type(arguments[3]) }), level)
  local ok, err_actual = pcall(func)
  err_actual = strip_error_location(err_actual)
  retargs[1] = err_actual
  -- rewrite the argument list into (actual, pattern) for the output message
  arguments.nofmt = {}
  arguments.n = 2
  arguments[1] = (ok and '(no error)' or err_actual)
  arguments[2] = pattern
  arguments.nofmt[1] = ok
  arguments.nofmt[2] = false
  set_failure_message(state, failure_message)

  if ok then
    return false, retargs
  end
  if err_actual == nil and pattern == nil then
    return true, {}
  end
  -- err_actual must be (convertible to) a string
  if util.hastostring(err_actual) then
    err_actual = tostring(err_actual)
    retargs[1] = err_actual
  end
  -- a nil pattern matches only a nil error (handled above); never pass nil
  -- to string.find/match, which would raise instead of failing the assertion
  if pattern == nil or type(err_actual) ~= 'string' then
    return false, retargs
  end
  local matched, captures = find_pattern(err_actual, pattern, init, plain)
  if matched then
    retargs = captures
  end
  return matched, retargs
end

-- is_true(value [, message]): passes when value == true.
local function is_true(state, arguments, level)
  util.tinsert(arguments, 2, true)
  set_failure_message(state, arguments[3])
  return arguments[1] == arguments[2]
end

-- is_false(value [, message]): passes when value == false.
local function is_false(state, arguments, level)
  util.tinsert(arguments, 2, false)
  set_failure_message(state, arguments[3])
  return arguments[1] == arguments[2]
end

-- Shared implementation of the type assertions: passes when type(value)
-- equals `etype`. Inserts the unformatted text "type <etype>" as the
-- expected argument so the output message reads naturally.
local function is_type(state, arguments, level, etype)
  util.tinsert(arguments, 2, "type " .. etype)
  arguments.nofmt = arguments.nofmt or {}
  arguments.nofmt[2] = true
  set_failure_message(state, arguments[3])
  return arguments.n > 1 and type(arguments[1]) == etype
end

-- returned_arguments(count, ...): passes when exactly `count` values follow
-- the first argument. Both counts are compared as strings.
local function returned_arguments(state, arguments, level)
  arguments[1] = tostring(arguments[1])
  arguments[2] = tostring(arguments.n - 1)
  arguments.nofmt = arguments.nofmt or {}
  arguments.nofmt[1] = true
  arguments.nofmt[2] = true
  if arguments.n < 2 then arguments.n = 2 end
  return arguments[1] == arguments[2]
end

-- `message` modifier: sets the failure message for the assertion chain.
local function set_message(state, arguments, level)
  state.failure_message = arguments[1]
end

-- Build an assertion that checks for one particular Lua type.
local function type_assertion(etype)
  return function(state, arguments, level)
    return is_type(state, arguments, level, etype)
  end
end

local is_boolean = type_assertion("boolean")
local is_number = type_assertion("number")
local is_string = type_assertion("string")
local is_table = type_assertion("table")
local is_nil = type_assertion("nil")
local is_userdata = type_assertion("userdata")
local is_function = type_assertion("function")
local is_thread = type_assertion("thread")

assert:register("modifier", "message", set_message)
assert:register("assertion", "true", is_true, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "false", is_false, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "boolean", is_boolean, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "number", is_number, "assertion.same.positive", "assertion.same.negative")
assert:register("assertion", "string", is_string, "assertion.same.positive", "assertion.same.negative")
do
  -- Remaining assertion registrations, one row per alias:
  -- { name, handler, positive message key, negative message key }.
  -- Rows are registered in listed order.
  local registrations = {
    { "table", is_table, "assertion.same.positive", "assertion.same.negative" },
    { "nil", is_nil, "assertion.same.positive", "assertion.same.negative" },
    { "userdata", is_userdata, "assertion.same.positive", "assertion.same.negative" },
    { "function", is_function, "assertion.same.positive", "assertion.same.negative" },
    { "thread", is_thread, "assertion.same.positive", "assertion.same.negative" },
    { "returned_arguments", returned_arguments, "assertion.returned_arguments.positive", "assertion.returned_arguments.negative" },
    { "same", same, "assertion.same.positive", "assertion.same.negative" },
    { "matches", matches, "assertion.matches.positive", "assertion.matches.negative" },
    { "match", matches, "assertion.matches.positive", "assertion.matches.negative" },
    { "near", near, "assertion.near.positive", "assertion.near.negative" },
    { "equals", equals, "assertion.equals.positive", "assertion.equals.negative" },
    { "equal", equals, "assertion.equals.positive", "assertion.equals.negative" },
    { "unique", unique, "assertion.unique.positive", "assertion.unique.negative" },
    { "error", has_error, "assertion.error.positive", "assertion.error.negative" },
    { "errors", has_error, "assertion.error.positive", "assertion.error.negative" },
    { "error_matches", error_matches, "assertion.error.positive", "assertion.error.negative" },
    { "error_match", error_matches, "assertion.error.positive", "assertion.error.negative" },
    { "matches_error", error_matches, "assertion.error.positive", "assertion.error.negative" },
    { "match_error", error_matches, "assertion.error.positive", "assertion.error.negative" },
    { "truthy", truthy, "assertion.truthy.positive", "assertion.truthy.negative" },
    { "falsy", falsy, "assertion.falsy.positive", "assertion.falsy.negative" },
  }
  for _, row in ipairs(registrations) do
    assert:register("assertion", row[1], row[2], row[3], row[4])
  end
end