Document poll_pending invariant #499

Open
cheesycod opened this issue Dec 7, 2024 · 4 comments

Comments

@cheesycod

When writing a task scheduler for Luau using mlua, I ran into an issue where resuming a Thread that is currently calling a Rust async function can deadlock. The resume returns an undocumented light userdata value, mlua::Lua::poll_pending(), which seems to mean that every thread returning it must be stored and periodically resumed until it completes.

Could this API please be documented?

@cheesycod
Author

cheesycod commented Dec 7, 2024

Found a nice issue here too:

--!nocheck
-- Create a coroutine
local co = coroutine.create(function()
    print("LUA: Doing _TEST_ASYNC_WORK(0)")
    _TEST_ASYNC_WORK(5)
    print("LUA: Done with _TEST_ASYNC_WORK(0)")
end)

coroutine.resume(co)

Rust side:

// `fs` is tokio::fs; `path` points at the Luau script above
lua.globals()
    .set(
        "_TEST_ASYNC_WORK",
        lua.create_async_function(|lua, n: u64| async move {
            //let task_mgr = taskmgr::get(&lua);
            println!("Async work: {}", n);
            tokio::time::sleep(std::time::Duration::from_secs(n)).await;
            println!("Async work done: {}", n);

            let created_table = lua.create_table()?;
            created_table.set("test", "test")?;

            Ok(created_table)
        })
        .expect("Failed to create async function"),
    )
    .expect("Failed to set _TEST_ASYNC_WORK global");

// The script returns nothing, so ask for `()` explicitly
lua.load(fs::read_to_string(&path).await?)
    .set_name(fs::canonicalize(&path).await?.to_string_lossy())
    .call_async::<()>(mlua::MultiValue::new())
    .await?;

This prints

Running script: "tests/tasks_with_yield.luau"
LUA: Doing _TEST_ASYNC_WORK(0)
Async work: 5

But you can make it deadlock with a simple modification:

--!nocheck
function taskA()
-- Licensed under the MIT license
-- See https://gist.github.com/jackdotink/5cd1757f599ba13d37f447fd7f41604c

local function resume_with_error_check(thread: thread, ...: any): ()
    local success, message = coroutine.resume(thread, ...)

    if not success then
        print(string.char(27) .. "[31m" .. message)
    end
end

type Task<T...> = thread | (T...) -> ...any

local last_tick = os.clock()

local waiting_threads: { [thread]: { resume: number } & ({ start: number } | { start: nil, n: number, [number]: any }) } =
    {}

local function process_waiting(): ()
    local processing = waiting_threads
    waiting_threads = {}

    for thread, data in processing do
        if coroutine.status(thread) == "dead" then
        elseif type(data) == "table" and last_tick >= data.resume then
            if data.start then
                resume_with_error_check(thread, last_tick - data.start)
            else
                resume_with_error_check(thread, table.unpack(data, 1, data.n))
            end
        else
            waiting_threads[thread] = data
        end
    end
end

local function wait(time: number): number
    waiting_threads[coroutine.running()] = { resume = last_tick + time, start = last_tick }
    return coroutine.yield()
end

local function delay<T...>(time: number, task: Task<T...>, ...: T...): thread
    local thread = if type(task) == "thread" then task else coroutine.create(task)

    local data: { [any]: any } = table.pack(...)
    data.resume = last_tick + time
    waiting_threads[thread] = data

    return thread
end

local deferred_threads: { { thread: thread, args: { [number]: any, n: number } } } = {}

local function process_deferred(): ()
    local i = 1

    while i <= #deferred_threads do
        local data = deferred_threads[i]

        if coroutine.status(data.thread) ~= "dead" then
            resume_with_error_check(data.thread, table.unpack(data.args, 1, data.args.n))
        end

        i += 1
    end

    table.clear(deferred_threads)
end

local function defer<T...>(task: Task<T...>, ...: T...): thread
    local thread = if type(task) == "thread" then task else coroutine.create(task)
    table.insert(deferred_threads, { thread = thread, args = table.pack(...) })

    return thread
end

local function spawn<T...>(task: Task<T...>, ...: T...): thread
    local thread = if type(task) == "thread" then task else coroutine.create(task)
    resume_with_error_check(thread, ...)

    return thread
end

local function close(thread: thread): ()
    coroutine.close(thread)
end

local function start(): never
    while true do
        last_tick = os.clock()

        process_waiting()
        process_deferred()
    end
end

return {
    wait = wait,
    delay = delay,
    defer = defer,
    spawn = spawn,
    close = close,
    start = start,
}
end

local task = taskA()

task.defer(function()
    print("LUA: Wait 1 second")
    task.wait(1)
    print("LUA: Doing _TEST_ASYNC_WORK(1)")
    _TEST_ASYNC_WORK(1)
    print("LUA: Done with _TEST_ASYNC_WORK(1)")
    print("LUA: Wait 1 second")
    task.wait(1)
    print("LUA: Doing _TEST_ASYNC_WORK(2)")
    _TEST_ASYNC_WORK(2)
    print("LUA: Done with _TEST_ASYNC_WORK(2)")

    print("LUA: Wait 1 second")
    task.wait(1)
    print("All done")
end)

task.start()

This is especially bad when running untrusted user code, since it can deadlock the entire execution.

cheesycod reopened this Dec 7, 2024
@khvzak
Member

khvzak commented Dec 7, 2024

This is by design and not a deadlock.

Rust futures return Poll::Pending when the result is not ready yet, handing control back to the (tokio) scheduler, which does the work behind the scenes.
When you call Rust async functions inside Lua coroutines, the pending value must be propagated back to Rust. mlua does this, but you are taking over control by calling coroutine.resume manually.
You need to ensure that if the result of coroutine.resume is the pending value, you yield it back to mlua.

You can also patch coroutine.resume globally to do the check.
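To see why the pending value has to travel all the way back to the executor, here is a std-only Rust model (no mlua or tokio; every name is hypothetical). It assumes a single-threaded executor in which the "timer" can only fire when control returns to the executor, roughly like a current-thread tokio runtime. `Propagate` returns `Poll::Pending` upward, the analogue of yielding `pending` back to mlua; `SpinInPlace` re-polls in a loop, the analogue of a Lua scheduler that keeps calling `coroutine.resume` without ever yielding.

```rust
use std::cell::Cell;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical stand-in for a tokio timer: it is only completed by the
/// executor's "driver" step, never by the task that is waiting on it.
struct Timer {
    done: Rc<Cell<bool>>,
}

impl Future for Timer {
    type Output = &'static str;
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.done.get() {
            Poll::Ready("timer fired")
        } else {
            Poll::Pending
        }
    }
}

/// Analogue of yielding `pending` back to mlua: Pending flows straight up.
struct Propagate<F>(F);

impl<F: Future + Unpin> Future for Propagate<F> {
    type Output = F::Output;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        Pin::new(&mut self.0).poll(cx)
    }
}

/// Analogue of a Lua loop that keeps calling `coroutine.resume` without
/// yielding. Capped at 1,000 polls so the demo terminates.
struct SpinInPlace<F>(F);

impl<F: Future + Unpin> Future for SpinInPlace<F> {
    type Output = Result<F::Output, &'static str>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        for _ in 0..1_000 {
            if let Poll::Ready(v) = Pin::new(&mut self.0).poll(cx) {
                return Poll::Ready(Ok(v));
            }
        }
        Poll::Ready(Err("spun 1000 times; the timer was never driven"))
    }
}

struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Minimal single-threaded executor: whenever the task returns Pending,
/// run the driver step (fire the timer) and poll again.
fn block_on<F: Future + Unpin>(mut fut: F, timer_done: &Rc<Cell<bool>>) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    loop {
        match Pin::new(&mut fut).poll(&mut cx) {
            Poll::Ready(v) => return v,
            Poll::Pending => timer_done.set(true),
        }
    }
}

fn main() {
    let done = Rc::new(Cell::new(false));
    let ok = block_on(Propagate(Timer { done: done.clone() }), &done);
    println!("propagate: {ok}"); // propagate: timer fired

    let done = Rc::new(Cell::new(false));
    let stuck = block_on(SpinInPlace(Timer { done: done.clone() }), &done);
    println!("spin in place: {stuck:?}"); // spin in place: Err("spun 1000 times; the timer was never driven")
}
```

With `Propagate` the timer fires on the second poll; with `SpinInPlace` it never fires, because the driver step never gets a chance to run. A real busy loop has no 1,000-iteration cap, so it would spin forever.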

@cheesycod
Author

Thanks for the reply. Would you mind sharing an example Rust implementation of coroutine.resume that keeps yielding until the thread is done?

@khvzak
Member

khvzak commented Dec 7, 2024

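lua.load(
    r#"
    local pending = ...
    local resume = coroutine.resume
    coroutine.resume = function(co, ...)
        while true do
            -- table.pack keeps the count so trailing nils are not lost
            local res = table.pack(resume(co, ...))
            if res[1] == true and res[2] == pending then
                -- the resumed thread is waiting on a Rust future:
                -- yield the pending value back to mlua, then retry
                coroutine.yield(pending)
            else
                return table.unpack(res, 1, res.n)
            end
        end
    end
"#,
)
.call::<()>(mlua::Lua::poll_pending())?;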

This should be called before any sandboxing is enabled, since Luau's sandbox makes the built-in libraries (including coroutine) read-only.
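The control flow of the patched coroutine.resume can be modeled the same way, again in std-only Rust with hypothetical names rather than mlua's API. `FakeCoroutine::resume` returns either a `Pending` marker (standing in for the `Lua::poll_pending()` light userdata) or a final result, and `patched_resume` turns each marker into a suspension of its caller instead of handing it back as an ordinary value. A busy-polling executor counts how many times control came back to it.

```rust
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Hypothetical model of what `coroutine.resume(co, ...)` returns when `co`
/// is inside an mlua async function: the sentinel, or the final result.
#[derive(Debug, PartialEq)]
enum Resumed {
    Pending,
    Finished(&'static str),
}

/// Hypothetical coroutine that reports "not ready" for `not_ready` resumes.
struct FakeCoroutine {
    not_ready: u32,
}

impl FakeCoroutine {
    fn resume(&mut self) -> Resumed {
        if self.not_ready > 0 {
            self.not_ready -= 1;
            Resumed::Pending
        } else {
            Resumed::Finished("done")
        }
    }
}

/// Model of the patched `coroutine.resume`: on the sentinel, suspend the
/// caller (Poll::Pending plays the role of `coroutine.yield(pending)`) and
/// retry on the next poll; any other result is returned unchanged.
async fn patched_resume(co: &mut FakeCoroutine) -> &'static str {
    poll_fn(|_cx| match co.resume() {
        Resumed::Pending => Poll::Pending,
        Resumed::Finished(v) => Poll::Ready(v),
    })
    .await
}

struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

/// Busy-polling executor that counts how often control returned to it
/// (one count per pending value propagated to the top).
fn block_on_counting<F: Future>(fut: F) -> (F::Output, u32) {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut round_trips = 0;
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(v) => return (v, round_trips),
            Poll::Pending => round_trips += 1,
        }
    }
}

fn main() {
    let mut co = FakeCoroutine { not_ready: 3 };
    let (result, round_trips) = block_on_counting(patched_resume(&mut co));
    println!("{result} after {round_trips} round trips"); // done after 3 round trips
}
```

Each round trip is a point where a real tokio runtime would get to drive timers and I/O. Without the patch, the marker comes back to Lua as a plain return value, and whatever called coroutine.resume has to deal with it; a scheduler that simply re-resumes in a loop never gives the runtime that chance.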
