-- Maintains the state of the assert engine as a linked list of snapshots.
-- Each snapshot records: formatters, parameters, spies and stubs.

-- Metatable shared by every snapshot. Besides providing the call behaviour
-- below, its identity is what marks a table as a valid snapshot.
local state_mt = {}

-- Calling a snapshot like a function reverts it; extra arguments are ignored
-- and nothing is returned.
state_mt.__call = function(snapshot)
  snapshot:revert()
end

-- metatable for the per-snapshot spy registry: keys and values are weak, so
-- registering a spy here does not by itself keep it from being collected
local spies_mt = { __mode = "kv" }

-- unique sentinel stored in place of nil parameter values, so that a parameter
-- explicitly set to nil can be told apart from one that is not set at all
local nilvalue = {}

-- the current snapshot: the most recently created one that has not been
-- reverted yet (nil until the first snapshot is created)
local current

-- exported module table
local state = {}

------------------------------------------------------
-- Reverts to a (specific) snapshot.
-- Snapshots created after the given one are reverted first. Then the
-- snapshot's formatters and parameters are dropped and every spy/stub
-- registered on it is reverted. Finally the snapshot is invalidated and its
-- predecessor becomes the current state.
-- The root snapshot (the one without a predecessor) is only emptied: it stays
-- valid and remains the current state, so the module always has a state.
-- Raises an error when the value provided is not a valid snapshot.
-- @param self (optional) the snapshot to revert to. If not provided, it will
-- revert to the last snapshot; if only the root snapshot exists nothing is done.
state.revert = function(self)
  if not self then
    -- no snapshot given, so move 1 up
    self = current
    if not self.previous then
      -- top of list, no previous one, nothing to do
      return
    end
  end
  if getmetatable(self) ~= state_mt then error("Value provided is not a valid snapshot", 2) end

  -- newer snapshots must be undone first (most recent one first)
  if self.next then
    self.next:revert()
  end
  -- revert formatters in 'last'
  self.formatters = {}
  -- revert parameters in 'last'
  self.parameters = {}
  -- revert spies/stubs in 'last'; clearing existing keys while traversing
  -- with pairs is allowed
  for s,_ in pairs(self.spies) do
    self.spies[s] = nil
    s:revert()
  end
  if not self.previous then
    -- root snapshot: there is nothing to fall back to, so keep it as the
    -- (now empty) current state instead of leaving the module without one
    current = self
    return
  end
  setmetatable(self, nil) -- invalidate as a snapshot
  current = self.previous
  current.next = nil
end

------------------------------------------------------
-- Creates a new snapshot and makes it the current state.
-- The new snapshot starts with empty formatters, parameters and spies, and is
-- linked behind the snapshot that was current before (if there was one).
-- @return snapshot table; it can be reverted by calling it or through its
-- `revert` method
state.snapshot = function()
  local new = setmetatable({
    formatters = {},
    parameters = {},
    spies = setmetatable({}, spies_mt),
    previous = current,
    revert = state.revert,
  }, state_mt)
  -- there is no current snapshot yet while the very first one is being created
  if current then current.next = new end
  current = new
  return current
end


--  FORMATTERS
-- Registers a formatter on the current snapshot.
-- The callback is inserted at the front of the list, so within a snapshot the
-- most recently added formatter is the first one consulted.
-- @param callback the formatter to register
state.add_formatter = function(callback)
  local registered = current.formatters
  table.insert(registered, 1, callback)
end

-- Removes a single registration of a formatter.
-- The given snapshot is searched first; only when the callback is not
-- registered there does the search continue with the snapshot before it, and
-- so on up to the root. Nothing happens if the callback is not found at all.
-- @param callback the formatter to remove (compared with ==)
-- @param s (optional) the snapshot to start searching in, defaults to the
-- current snapshot
state.remove_formatter = function(callback, s)
  s = s or current
  for i, v in ipairs(s.formatters) do
    if v == callback then
      table.remove(s.formatters, i)
      -- found and removed: stop here so registrations of the same callback in
      -- earlier snapshots are left untouched
      return
    end
  end
  -- wasn't found, so traverse up 1 state
  if s.previous then
    return state.remove_formatter(callback, s.previous)
  end
end

-- Formats a value using the registered formatters.
-- Formatters are tried in list order, starting with the given snapshot and
-- then moving through each earlier snapshot; the first result that is not nil
-- wins.
-- @param val the value to format
-- @param s (optional) the snapshot to start at, defaults to the current one
-- @param fmtargs passed unchanged as second argument to every formatter
-- @return the first non-nil formatter result, or nil if none handled the value
state.format_argument = function(val, s, fmtargs)
  local snap = s or current
  repeat
    for _, formatter in ipairs(snap.formatters) do
      local formatted = formatter(val, fmtargs)
      if formatted ~= nil then
        return formatted
      end
    end
    -- no formatter at this level handled it, continue one snapshot up
    snap = snap.previous
  until not snap
  return nil -- reached the top of the list, couldn't format
end


--  PARAMETERS
-- Sets a parameter on the current snapshot.
-- A nil value is stored as the `nilvalue` sentinel, because assigning nil to a
-- table field would remove the entry instead of recording it.
-- @param name the parameter name
-- @param value the value to store (may be nil)
state.set_parameter = function(name, value)
  -- `nilvalue` is a table and therefore truthy, so and/or is safe here
  local stored = (value == nil) and nilvalue or value
  current.parameters[name] = stored
end

-- Looks up a parameter value.
-- The given snapshot is checked first; while the parameter has no entry the
-- lookup moves to the snapshot before it, until the top of the list.
-- @param name the parameter name
-- @param s (optional) the snapshot to start at, defaults to the current one
-- @return the stored value, or nil when the parameter is not set anywhere or
-- was explicitly set to nil
state.get_parameter = function(name, s)
  local snap = s or current
  local found = snap.parameters[name]
  while found == nil and snap.previous do
    -- no entry at this level, check 1 up in the list
    snap = snap.previous
    found = snap.parameters[name]
  end
  if found == nilvalue then
    -- the sentinel marks a parameter that was explicitly set to nil
    return nil
  end
  return found
end

--  SPIES / STUBS
-- Registers a spy/stub on the current snapshot, so that it gets reverted when
-- that snapshot is reverted.
-- @param spy the spy or stub to track (used as a key in the registry)
state.add_spy = function(spy)
  local tracked = current.spies
  tracked[spy] = true
end

-- create the initial (root) snapshot, so `current` is set once the module has
-- been loaded
state.snapshot()

return state