local s = require 'say'
local astate = require 'luassert.state'
local util = require 'luassert.util'
local unpack = require 'luassert.compatibility'.unpack
local obj   -- the returned module table
local level_mt = {}

-- list of namespaces
local namespace = require 'luassert.namespaces'

-- Builds the text reported for a failed assertion.
--   assertion_message: passed to `say` together with the formatted arguments
--   failure_message:   optional caller-supplied message (any type); when present
--                      it is placed on its own line before the assertion text
--   args:              assertion arguments, formatted in place by obj:format
-- Returns the combined text, or whichever of the two parts is available.
local function geterror(assertion_message, failure_message, args)
  if util.hastostring(failure_message) then
    failure_message = tostring(failure_message)
  elseif failure_message ~= nil then
    local formatted = astate.format_argument(failure_message)
    -- no formatter found: fall back to tostring (same rule obj:format applies)
    -- so the caller's message is never silently dropped
    if formatted == nil then
      formatted = tostring(failure_message)
    end
    failure_message = formatted
  end
  local message = s(assertion_message, obj:format(args))
  if message and failure_message then
    message = failure_message .. "\n" .. message
  end
  return message or failure_message
end

-- Metatable for assertion state objects. Indexing a state collects tokens
-- (assert.is_not.equal -> "is", "not", "equal"); calling it runs the chain.
local __state_meta = {

  -- Runs the collected token chain with the given arguments.
  -- With an assertion in the chain: applies the modifiers, runs the assertion
  -- and raises on failure; otherwise returns the assertion's return arguments
  -- (or the original arguments when it supplies none).
  -- Without an assertion: resets the tokens, hands the arguments to each
  -- modifier and returns the state so the chain can continue.
  __call = function(self, ...)
    local keys = util.extract_keys("assertion", self.tokens)

    -- pack the arguments once; the explicit count preserves trailing nils
    local arguments = {...}
    arguments.n = select('#', ...)

    -- the last key that names a registered assertion wins
    local assertion
    for _, key in ipairs(keys) do
      local candidate = namespace.assertion[key]
      if candidate then
        assertion = candidate
      end
    end

    if not assertion then
      -- modifier-only call
      self.tokens = {}

      for _, key in ipairs(keys) do
        local modifier = namespace.modifier[key]
        if modifier then
          modifier.callback(self, arguments, util.errorlevel())
        end
      end

      return self
    end

    -- let every modifier in the chain adjust the state before asserting
    for _, key in ipairs(keys) do
      local modifier = namespace.modifier[key]
      if modifier then
        modifier.callback(self)
      end
    end

    local val, retargs = assertion.callback(self, arguments, util.errorlevel())

    -- fails when the truthiness of the result disagrees with self.mod
    if (not val) == self.mod then
      local message
      if self.mod then
        message = assertion.positive_message
      else
        message = assertion.negative_message
      end
      local err = geterror(message, rawget(self, "failure_message"), arguments)
      error(err or "assertion failed!", util.errorlevel())
    end

    if retargs then
      return unpack(retargs)
    end
    return ...
  end,

  -- Splits the accessed key on underscores, appends the lowercased pieces to
  -- the token list and returns the state itself so accesses can be chained.
  __index = function(self, key)
    local lowered = key:lower()
    for piece in lowered:gmatch('[^_]+') do
      table.insert(self.tokens, piece)
    end

    return self
  end
}

obj = {}

-- creates a fresh assertion state: positive mode, no tokens collected yet
function obj.state()
  local fresh = { mod = true, tokens = {} }
  return setmetatable(fresh, __state_meta)
end

-- registers a function in a namespace, stored under the lowercased name;
-- the namespace is created when it does not exist yet
function obj:register(nspace, name, callback, positive_message, negative_message)
  local lowername = name:lower()
  local entries = namespace[nspace]
  if not entries then
    entries = {}
    namespace[nspace] = entries
  end
  entries[lowername] = {
    name = lowername,
    callback = callback,
    positive_message = positive_message,
    negative_message = negative_message,
  }
end

-- unregisters a function in a namespace
-- (a missing namespace is created empty, as a side effect)
function obj:unregister(nspace, name)
  local lowername = name:lower()
  local entries = namespace[nspace]
  if not entries then
    entries = {}
    namespace[nspace] = entries
  end
  entries[lowername] = nil
end

-- registers a formatter
-- a formatter takes a single argument, and converts it to a string, or returns nil if it cannot format the argument
function obj:add_formatter(callback)
  astate.add_formatter(callback)
end

-- unregisters a formatter
function obj:remove_formatter(fmtr)
  astate.remove_formatter(fmtr)
end

-- formats the arguments in place and returns the same table
--   args.n       number of arguments (trailing nils would otherwise be lost)
--   args.nofmt   positions flagged here are left unformatted
--   args.fmtargs per-position extra argument handed to the formatter
function obj:format(args)
  local skip = args.nofmt or {}
  local extras = args.fmtargs or {}
  local count = args.n or #args
  -- numeric loop: pairs/ipairs would stumble over nil arguments
  for index = 1, count do
    if not skip[index] then
      local raw = args[index]
      local text = astate.format_argument(raw, nil, extras[index])
      if text == nil then
        text = tostring(raw) -- no formatter found
      end
      args[index] = text
    end
  end
  return args
end

-- stores a named parameter in the assertion state module
function obj:set_parameter(name, value)
  astate.set_parameter(name, value)
end

-- reads a named parameter from the assertion state module
function obj:get_parameter(name)
  return astate.get_parameter(name)
end

-- hands a spy over to the assertion state module
function obj:add_spy(spy)
  astate.add_spy(spy)
end

-- returns whatever astate.snapshot() produces
function obj:snapshot()
  return astate.snapshot()
end

-- wraps a value as a level-value, tagged with the private level metatable
function obj:level(level)
  local wrapped = { level = level }
  return setmetatable(wrapped, level_mt)
end

-- returns the level if a level-value, otherwise nil
function obj:get_level(level)
  if getmetatable(level) == level_mt then
    return level.level
  end
  return nil -- not a valid error-level
end

-- Metatable of the module table itself.
local __meta = {

  -- Lets the module be called like the builtin assert: raises when `bool` is
  -- falsy, otherwise hands back every argument it was given.
  -- `level` may be a level-value; anything else counts as level 1.
  __call = function(self, bool, message, level, ...)
    if bool then
      return bool, message, level, ...
    end
    local base_level = self:get_level(level) or 1
    -- one extra level so the error is not attributed to this metamethod
    error(message or "assertion failed!", base_level + 1)
  end,

  -- Keys missing from the module start a token chain on a brand-new state.
  __index = function(self, key)
    local own = rawget(self, key)
    if own then
      return own
    end
    local fresh_state = self.state()
    return fresh_state[key]
  end,

}

-- export the module table; __meta makes it callable and resolves unknown keys
return setmetatable(obj, __meta)