local co = coroutine local errors = require "plenary.errors" local traceback_error = errors.traceback_error local f = require "plenary.functional" local tbl = require "plenary.tbl" local M = {} ---because we can't store varargs local function callback_or_next(step, thread, callback, ...) local stat = f.first(...) if not stat then error(string.format("The coroutine failed with this message: %s", f.second(...))) end if co.status(thread) == "dead" then (callback or function() end)(select(2, ...)) else assert(select("#", select(2, ...)) == 1, "expected a single return value") local returned_future = f.second(...) assert(type(returned_future) == "function", "type error :: expected func") returned_future(step) end end ---@class Future ---Something that will give a value when run ---Executes a future with a callback when it is done ---@param future Future: the future to execute ---@param callback function: the callback to call when done local execute = function(future, callback) assert(type(future) == "function", "type error :: expected func") local thread = co.create(future) local step step = function(...) callback_or_next(step, thread, callback, co.resume(thread, ...)) end step() end ---Creates an async function with a callback style function. ---@param func function: A callback style function to be converted. The last argument must be the callback. ---@param argc number: The number of arguments of func. Must be included. ---@return function: Returns an async function M.wrap = function(func, argc) if type(func) ~= "function" then traceback_error("type error :: expected func, got " .. type(func)) end if type(argc) ~= "number" and argc ~= "vararg" then traceback_error "expected argc to be a number or string literal 'vararg'" end return function(...) local params = tbl.pack(...) local function future(step) if step then if type(argc) == "number" then params[argc] = step params.n = argc else table.insert(params, step) -- change once not optional params.n = params.n + 1 end return func(tbl.unpack(params)) else return co.yield(future) end end return future end end ---Return a new future that when run will run all futures concurrently. ---@param futures table: the futures that you want to join ---@return Future: returns a future M.join = M.wrap(function(futures, step) local len = #futures local results = {} local done = 0 if len == 0 then return step(results) end for i, future in ipairs(futures) do assert(type(future) == "function", "type error :: future must be function") local callback = function(...) results[i] = { ... } done = done + 1 if done == len then step(results) end end future(callback) end end, 2) ---Returns a future that when run will select the first future that finishes ---@param futures table: The future that you want to select ---@return Future M.select = M.wrap(function(futures, step) local selected = false for _, future in ipairs(futures) do assert(type(future) == "function", "type error :: future must be function") local callback = function(...) if not selected then selected = true step(...) end end future(callback) end end, 2) ---Use this to either run a future concurrently and then do something else ---or use it to run a future with a callback in a non async context ---@param future Future ---@param callback function M.run = function(future, callback) future(callback or function() end) end ---Same as run but runs multiple futures ---@param futures table ---@param callback function M.run_all = function(futures, callback) M.run(M.join(futures), callback) end ---Await a future, yielding the current function ---@param future Future ---@return any: returns the result of the future when it is done M.await = function(future) assert(type(future) == "function", "type error :: expected function to await") return future(nil) end ---Same as await but can await multiple futures. ---If the futures have libuv leaf futures they will be run concurrently ---@param futures table ---@return table: returns a table of results that each future returned. Note that if the future returns multiple values they will be packed into a table. M.await_all = function(futures) assert(type(futures) == "table", "type error :: expected table") return M.await(M.join(futures)) end ---suspend a coroutine M.suspend = co.yield ---create a async scope M.scope = function(func) M.run(M.future(func)) end --- Future a :: a -> (a -> ()) --- turns this signature --- ... -> Future a --- into this signature --- ... -> () M.void = function(async_func) return function(...) async_func(...)(function() end) end end M.async_void = function(func) return M.void(M.async(func)) end ---creates an async function ---@param func function ---@return function: returns an async function M.async = function(func) if type(func) ~= "function" then traceback_error("type error :: expected func, got " .. type(func)) end return function(...) local args = tbl.pack(...) local function future(step) if step == nil then return func(tbl.unpack(args)) else execute(future, step) end end return future end end ---creates a future ---@param func function ---@return Future M.future = function(func) return M.async(func)() end ---An async function that when awaited will await the scheduler to be able to call the api. M.scheduler = M.wrap(vim.schedule, 1) return M