---@brief [[
---An adaptation of luafun for neovim.
---This library will use neovim specific functions.
---Some documentation is the same as from luafun.
---Some extra functions are present that are not in luafun
---@brief ]]

local co = coroutine
local f = require "plenary.functional"

--------------------------------------------------------------------------------
-- Tools
--------------------------------------------------------------------------------

local exports = {}

---@class Iterator
---@field gen function
---@field param any
---@field state any
local Iterator = {}
Iterator.__index = Iterator

---Makes a for loop work.
---If called without param or state, the generator is started from the stored param and state.
---This is needed because the original luafun also returns param and state in addition to the
---generator as a multival, and when iterators are used as expressions the multivals can bleed.
---For example i.iter { 1, 2, i.iter { 3, 4 } } would not work because the inner iterator would
---return a multival, polluting the list with internal values.
---So instead we do not return param and state as multivals when doing wrap.
---This causes the first loop iteration to call us with nil param and state, so we use `or`
---to fall back to the internal starting param and state.
function Iterator:__call(param, state)
  return self.gen(param or self.param, state or self.state)
end

function Iterator:__tostring()
  return ""
end

-- A special hack for zip/chain: when the last three arguments are a wrapped iterator
-- followed by its own param and state, the trailing param and state are not counted.
local numargs = function(...)
  local n = select("#", ...)
  if n >= 3 then
    -- Fix last argument
    local it = select(n - 2, ...)
    if
      type(it) == "table"
      and getmetatable(it) == Iterator
      and it.param == select(n - 1, ...)
      and it.state == select(n, ...)
    then
      return n - 2
    end
  end
  return n
end

-- Returns the values following state_x, or nil when state_x is nil (iterator exhausted).
local return_if_not_empty = function(state_x, ...)
  if state_x == nil then
    return nil
  end
  return ...
end

-- Returns state_x followed by the results of fun(...), or nil when state_x is nil.
local call_if_not_empty = function(fun, state_x, ...)
  if state_x == nil then
    return nil
  end
  return state_x, fun(...)
end

--------------------------------------------------------------------------------
-- Basic Functions
--------------------------------------------------------------------------------

local nil_gen = function(_param, _state)
  return nil
end

local pairs_gen = pairs {}

-- Generator over all keys of a (non-list) table: uses the key as the state and
-- yields the key and its value.
local hash_gen = function(map, key)
  -- `value` must be local; assigning it bare would create a global.
  local value
  key, value = pairs_gen(map, key)
  return key, key, value
end

-- Generator over the characters of a string: the state is the byte index.
local string_gen = function(param, state)
  state = state + 1
  if state > #param then
    return nil
  end
  local r = string.sub(param, state, state)
  return state, r
end

-- Converts any iterable object into a raw (gen, param, state) triplet.
-- Wrapped iterators are unwrapped, lists use ipairs, other tables iterate all keys,
-- functions are used as the generator directly and strings yield their characters.
local rawiter = function(obj, param, state)
  assert(obj ~= nil, "invalid iterator")
  if type(obj) == "table" then
    if getmetatable(obj) == Iterator then
      return obj.gen, obj.param, obj.state
    end
    -- vim.tbl_islist is deprecated in newer Neovim releases in favour of vim.islist
    local islist = vim.islist or vim.tbl_islist
    if islist(obj) then
      return ipairs(obj)
    end
    -- hash
    return hash_gen, obj, nil
  elseif type(obj) == "function" then
    return obj, param, state
  elseif type(obj) == "string" then
    if #obj == 0 then
      return nil_gen, nil, nil
    end
    return string_gen, obj, 0
  end
  -- tostring is required: Lua 5.1's %s rejects non-string values such as booleans
  error(string.format('object %s of type "%s" is not iterable', tostring(obj), type(obj)))
end

---Wraps the iterator triplet into a table to allow metamethods and calling with method form.
---Important! We do not return param and state as multivals like the original luafun.
---See the __call metamethod for more information.
---@param gen any
---@param param any
---@param state any
---@return Iterator
local function wrap(gen, param, state)
  return setmetatable({
    gen = gen,
    param = param,
    state = state,
  }, Iterator)
end

---Unwrap an iterator metatable into the iterator triplet
---@param self Iterator
---@return any
---@return any
---@return any
local unwrap = function(self)
  return self.gen, self.param, self.state
end

---Create an iterator from an object
---@param obj any
---@param param any (optional)
---@param state any (optional)
---@return Iterator
local iter = function(obj, param, state)
  return wrap(rawiter(obj, param, state))
end

exports.iter = iter
exports.wrap = wrap
exports.unwrap = unwrap

---Calls fn with the values of every step until the iterator is exhausted.
---@param fn function
function Iterator:for_each(fn)
  local param, state = self.param, self.state
  repeat
    state = call_if_not_empty(fn, self.gen(param, state))
  until state == nil
end

---Returns an iterator whose progress is kept inside a coroutine.
---The first value of each step (via f.first) is used as the loop state.
---@return Iterator
function Iterator:stateful()
  return wrap(
    co.wrap(function()
      self:for_each(function(...)
        co.yield(f.first(...), ...)
      end)
      -- to make sure that we always return nil if there are no more values
      while true do
        co.yield()
      end
    end),
    nil,
    nil
  )
end

-- function Iterator:stateful()
--   local gen, param, state = self.gen, self.param, self.state
--   local function return_and_set_state(state_x, ...)
--     state = state_x
--     if state == nil then return end
--     return state_x, ...
--   end
--   local stateful_gen = function()
--     return return_and_set_state(gen(param, state))
--   end
--   return wrap(stateful_gen, false, false)
-- end

--------------------------------------------------------------------------------
-- Generators
--------------------------------------------------------------------------------

-- Ascending range step: param is { stop, step }, state is the previous value.
local range_gen = function(param, state)
  local stop, step = param[1], param[2]
  state = state + step
  if state > stop then
    return nil
  end
  return state, state
end

-- Descending range step: param is { stop, step } with a negative step.
local range_rev_gen = function(param, state)
  local stop, step = param[1], param[2]
  state = state + step
  if state < stop then
    return nil
  end
  return state, state
end

---Creates a range iterator.
---range(n) counts from 1 (or -1 for negative n) to n, range(0) is empty,
---and range(start, stop) counts towards stop with a step of 1 or -1.
---@param start number
---@param stop number
---@param step number
---@return Iterator
local range = function(start, stop, step)
  if step == nil then
    if stop == nil then
      if start == 0 then
        -- Must be wrapped like every other result so that methods such as
        -- :tolist() can be called on it.
        return wrap(nil_gen, nil, nil)
      end
      stop = start
      start = stop > 0 and 1 or -1
    end
    step = start <= stop and 1 or -1
  end

  assert(type(start) == "number", "start must be a number")
  assert(type(stop) == "number", "stop must be a number")
  assert(type(step) == "number", "step must be a number")
  assert(step ~= 0, "step must not be zero")

  if step > 0 then
    return wrap(range_gen, { stop, step }, start - step)
  elseif step < 0 then
    return wrap(range_rev_gen, { stop, step }, start - step)
  end
end
exports.range = range

-- Yields all packed arguments forever; param_x.n keeps values around nil holes.
local duplicate_table_gen = function(param_x, state_x)
  return state_x + 1, unpack(param_x, 1, param_x.n)
end

-- Yields fun(0), fun(1), ... forever.
local duplicate_fun_gen = function(param_x, state_x)
  return state_x + 1, param_x(state_x)
end

-- Yields the same single value forever.
local duplicate_gen = function(param_x, state_x)
  return state_x + 1, param_x
end

---Creates an infinite iterator that will yield the arguments.
---If multiple arguments are passed, the args will be packed and unpacked
---(nil arguments are preserved).
---@param ...: the arguments to duplicate
---@return Iterator
local duplicate = function(...)
  local nargs = select("#", ...)
  if nargs <= 1 then
    return wrap(duplicate_gen, (...), 0)
  end
  -- Record the count explicitly: `#` is unreliable when some arguments are nil.
  return wrap(duplicate_table_gen, { n = nargs, ... }, 0)
end
exports.duplicate = duplicate

---Creates an iterator from a function.
---The function is called with 0, 1, 2, ... and its results are yielded.
---NOTE: if the function is a closure and modifies state, the resulting iterator will not be stateless
---@param fun function
---@return Iterator
local from_fun = function(fun)
  assert(type(fun) == "function")
  return wrap(duplicate_fun_gen, fun, 0)
end
exports.from_fun = from_fun

---Creates an infinite iterator that will yield zeros.
---This is an alias to calling duplicate(0)
---@return Iterator
local zeros = function()
  return wrap(duplicate_gen, 0, 0)
end
exports.zeros = zeros

---Creates an infinite iterator that will yield ones.
---This is an alias to calling duplicate(1)
---@return Iterator
local ones = function()
  return wrap(duplicate_gen, 1, 0)
end
exports.ones = ones

-- Yields math.random(lo, hi) forever; param_x is { lo, hi }.
local rands_gen = function(param_x, _state_x)
  return 0, math.random(param_x[1], param_x[2])
end

-- Yields math.random() forever.
local rands_nil_gen = function(_param_x, _state_x)
  return 0, math.random()
end

---Creates an infinite iterator that will yield random values.
---With no arguments the values are floats from math.random().
---With only n, the values are integers in [0, n); with n and m, integers in [n, m).
---@param n number
---@param m number
---@return Iterator
local rands = function(n, m)
  if n == nil and m == nil then
    return wrap(rands_nil_gen, 0, 0)
  end
  assert(type(n) == "number", "invalid first arg to rands")
  if m == nil then
    m = n
    n = 0
  else
    assert(type(m) == "number", "invalid second arg to rands")
  end
  assert(n < m, "empty interval")
  return wrap(rands_gen, { n, m - 1 }, 0)
end
exports.rands = rands

-- param is { input, sep }; state is the byte index where the next piece starts.
local split_gen = function(param, state)
  local input, sep = param[1], param[2]
  local input_len = #input
  if state > input_len + 1 then
    return
  end
  -- plain find: the separator is matched literally, not as a Lua pattern
  local start, finish = string.find(input, sep, state, true)
  if not start then
    start = input_len + 1
    finish = input_len + 1
  end
  local sub_str = input:sub(state, start - 1)
  return finish + 1, sub_str
end

---Return an iterator of substrings separated by a string.
---The separator is matched literally. Adjacent separators produce empty substrings
---and an empty input yields a single empty string.
---@param input string: the string to split
---@param sep string: the separator to find and split based on, must not be empty
---@return Iterator
local split = function(input, sep)
  -- An empty separator matches at every position without advancing, which would
  -- make split_gen yield "" forever.
  assert(sep ~= "", "separator must not be empty")
  return wrap(split_gen, { input, sep }, 1)
end
exports.split = split

---Splits a string based on a single space
---An alias for split(input, " ")
---@param input string
---@return Iterator
local words = function(input)
  return split(input, " ")
end
exports.words = words

---Splits a string on "\n". Any "\r" of a "\r\n" line ending stays in the pieces.
---@param input string
---@return Iterator
local lines = function(input)
  -- TODO: platform specific linebreaks
  return split(input, "\n")
end
exports.lines = lines

--------------------------------------------------------------------------------
-- Transformations
--------------------------------------------------------------------------------

-- param is { gen, param, fun } of the source iterator and the mapping function.
local map_gen = function(param, state)
  local gen_x, param_x, fun = param[1], param[2], param[3]
  return call_if_not_empty(fun, gen_x(param_x, state))
end

---Iterator adapter that maps the previous iterator with a function
---@param fun function: The function to map with. Will be called on each element
---@return Iterator
function Iterator:map(fun)
  return wrap(map_gen, { self.gen, self.param, fun }, self.state)
end

local flatten_gen1
do
  -- Converts one step of the rebuilt (flattened) iterator into a flatten state.
  local it = function(new_iter, state_x, ...)
    if state_x == nil then
      return nil
    end
    return { new_iter.gen, new_iter.param, state_x }, ...
  end

  -- The state is { gen, param, state } of the iterator currently being flattened.
  flatten_gen1 = function(state, state_x, ...)
    if state_x == nil then
      return nil
    end

    local first_arg = f.first(...)

    -- experimental part
    if getmetatable(first_arg) == Iterator then
      -- attach the iterator to the rest
      local new_iter = (first_arg .. wrap(state[1], state[2], state_x)):flatten()
      -- advance the iterator by one
      return it(new_iter, new_iter.gen(new_iter.param, new_iter.state))
    end

    return { state[1], state[2], state_x }, ...
  end
end

local flatten_gen = function(_, state)
  if state == nil then
    return
  end
  local gen_x, param_x, state_x = state[1], state[2], state[3]
  return flatten_gen1(state, gen_x(param_x, state_x))
end

---Iterator adapter that will recursively flatten nested iterator structure
---@return Iterator
function Iterator:flatten()
  return wrap(flatten_gen, false, { self.gen, self.param, self.state })
end

--------------------------------------------------------------------------------
-- Filtering
--------------------------------------------------------------------------------

-- Single-value filter: advances until fun(a) is truthy or the iterator is exhausted.
local filter1_gen = function(fun, gen_x, param_x, state_x, a)
  while true do
    if state_x == nil or fun(a) then
      break
    end
    state_x, a = gen_x(param_x, state_x)
  end
  return state_x, a
end

-- filterm_gen and filterm_gen_shrink call each other
-- because we can't assign a vararg mutably in a while loop like filter1_gen
-- so we have to use recursion in calling both of these functions
local filterm_gen
local filterm_gen_shrink = function(fun, gen_x, param_x, state_x)
  return filterm_gen(fun, gen_x, param_x, gen_x(param_x, state_x))
end

filterm_gen = function(fun, gen_x, param_x, state_x, ...)
  if state_x == nil then
    return nil
  end
  if fun(...) then
    return state_x, ...
  end
  return filterm_gen_shrink(fun, gen_x, param_x, state_x)
end

-- Picks the single-value loop when the step yielded at most one value,
-- otherwise the multival recursion.
local filter_detect = function(fun, gen_x, param_x, state_x, ...)
  if select("#", ...) < 2 then
    return filter1_gen(fun, gen_x, param_x, state_x, ...)
  else
    return filterm_gen(fun, gen_x, param_x, state_x, ...)
  end
end

local filter_gen = function(param, state_x)
  local gen_x, param_x, fun = param[1], param[2], param[3]
  return filter_detect(fun, gen_x, param_x, gen_x(param_x, state_x))
end

---Iterator adapter that will filter values
---@param fun function: The function to filter values with. If the function returns true, the value will be kept.
---@return Iterator
function Iterator:filter(fun)
  return wrap(filter_gen, { self.gen, self.param, fun }, self.state)
end

--------------------------------------------------------------------------------
-- Reducing
--------------------------------------------------------------------------------

---Checks whether any of the values in the iterator satisfy a predicate.
---Returns the first truthy result of fun, or nil when no value satisfies it.
---@param fun function
---@return any
function Iterator:any(fun)
  local r
  local state, param, gen = self.state, self.param, self.gen
  repeat
    state, r = call_if_not_empty(fun, gen(param, state))
  until state == nil or r
  return r
end

---Returns true if all of the values in the iterator satisfy a predicate
---(an empty iterator returns true).
---@param fun function
---@return boolean
function Iterator:all(fun)
  local r
  local state, param, gen = self.state, self.param, self.gen
  repeat
    state, r = call_if_not_empty(fun, gen(param, state))
  until state == nil or not r
  return state == nil
end

---Finds a value that is equal to the provided value or satisfies a predicate.
---With a predicate all values of the matching step are returned; otherwise the
---first value equal (==) to val_or_fn. Returns nil when nothing matches.
---@param val_or_fn any
---@return any
function Iterator:find(val_or_fn)
  local gen, param, state = self.gen, self.param, self.state
  if type(val_or_fn) == "function" then
    return return_if_not_empty(filter_detect(val_or_fn, gen, param, gen(param, state)))
  else
    for _, r in gen, param, state do
      if r == val_or_fn then
        return r
      end
    end
    return nil
  end
end

---Turns an iterator into a list.
---If the iterator yields multivals only the first multival will be used.
---@return table
function Iterator:tolist()
  local list = {}
  self:for_each(function(a)
    table.insert(list, a)
  end)
  return list
end

---Turns an iterator into a list.
---If the iterator yields multivals all multivals will be used and packed into a table.
---@return table
function Iterator:tolistn()
  local list = {}
  self:for_each(function(...)
    table.insert(list, { ... })
  end)
  return list
end

---Turns an iterator into a map.
---The first multival that the iterator yields will be the key.
---The second multival that the iterator yields will be the value.
---@return table
function Iterator:tomap()
  local map = {}
  self:for_each(function(key, value)
    map[key] = value
  end)
  return map
end

--------------------------------------------------------------------------------
-- Compositions
--------------------------------------------------------------------------------

-- chain_gen_r1 and chain_gen_r2 call each other.
-- param holds gen/param/state triplets at indices 3i-2, 3i-1, 3i;
-- the state is { index of the active iterator, its state }.
local chain_gen_r1

-- Handles one step of the active iterator: when it is exhausted, switch to the next one.
local chain_gen_r2 = function(param, state, state_x, ...)
  if state_x == nil then
    local i = state[1] + 1
    -- Test the generator slot: a param slot may legitimately be nil
    -- (e.g. stateful iterators), which must not end the chain early.
    if param[3 * i - 2] == nil then
      return nil
    end
    state_x = param[3 * i]
    return chain_gen_r1(param, { i, state_x })
  end
  return { state[1], state_x }, ...
end

chain_gen_r1 = function(param, state)
  local i, state_x = state[1], state[2]
  local gen_x, param_x = param[3 * i - 2], param[3 * i - 1]
  return chain_gen_r2(param, state, gen_x(param_x, state_x))
end

---Make an iterator that returns elements from the first iterator until it is exhausted, then proceeds to the next iterator,
---until all of the iterators are exhausted.
---Used for treating consecutive iterators as a single iterator.
---Any iterable accepted by iter (iterators, lists, maps, strings, functions) can be chained.
---Infinity iterators are supported, but are not recommended.
---@param ...: the iterators to chain
---@return Iterator
local chain = function(...)
  local n = numargs(...)
  if n == 0 then
    return wrap(nil_gen, nil, nil)
  end

  local param = { [3 * n] = 0 }
  for i = 1, n do
    -- parentheses keep only the i-th argument so later ones are not taken as param/state
    local gen_x, param_x, state_x = rawiter((select(i, ...)))
    param[3 * i - 2] = gen_x
    param[3 * i - 1] = param_x
    param[3 * i] = state_x
  end
  return wrap(chain_gen_r1, param, { 1, param[3] })
end
Iterator.chain = chain
Iterator.__concat = chain
exports.chain = chain

-- Pulls one value from each zipped iterator in turn, collecting their new states in
-- state_new and prepending each value to the varargs; returns nil as soon as any
-- iterator is exhausted. param.n is the number of zipped iterators.
local function zip_gen_r(param, state, state_new, ...)
  if #state_new == param.n then
    return state_new, ...
  end
  local i = #state_new + 1
  local gen_x, param_x = param[2 * i - 1], param[2 * i]
  local state_x, r = gen_x(param_x, state[i])
  if state_x == nil then
    return nil
  end
  table.insert(state_new, state_x)
  return zip_gen_r(param, state, state_new, r, ...)
end

local zip_gen = function(param, state)
  return zip_gen_r(param, state, {})
end

---Return a new iterator where i-th return value contains the i-th element from each of the iterators.
---The returned iterator is truncated in length to the length of the shortest iterator.
---For multi-return iterators only the first variable is used.
---@param ...: the iterators to zip
---@return Iterator
local zip = function(...)
  local n = numargs(...)
  if n == 0 then
    return wrap(nil_gen, nil, nil)
  end

  -- The count is stored explicitly because param has nil holes whenever an
  -- iterator's param is nil, which makes `#param` unreliable.
  local param = { [2 * n] = 0, n = n }
  local state = { [n] = 0 }
  for i = 1, n do
    -- Stored in reverse argument order: zip_gen_r prepends each value, so the
    -- yielded values come out in argument order.
    local gen_x, param_x, state_x = rawiter((select(n - i + 1, ...)))
    param[2 * i - 1] = gen_x
    param[2 * i] = param_x
    state[i] = state_x
  end
  return wrap(zip_gen, param, state)
end
Iterator.zip = zip
Iterator.__div = zip
exports.zip = zip

return exports