Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Interrupt on windows! 😯 #60

Merged
merged 20 commits into from
Sep 27, 2023
Merged
144 changes: 132 additions & 12 deletions src/Malt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,57 @@ include("./shared.jl")



"""
```julia
poll(query::Function, timeout::Real=Inf64, interval::Real=1/20)::Bool
```

Keep running your function `query()` in intervals until it returns `true`, or until `timeout` seconds have passed.

`poll` returns `true` if `query()` returned `true`. If `timeout` seconds have passed, `poll` returns `false`.

# Example
```julia
vals = [1,2,3]

@async for i in 1:5
sleep(1)
vals[3] = 99
end

poll(8 #= seconds =#) do
vals[3] == 99
end # returns `true` (after 5 seconds)!

###

@async for i in 1:5
sleep(1)
vals[3] = 5678
end

poll(2 #= seconds =#) do
vals[3] == 5678
end # returns `false` (after 2 seconds)!
```
"""
function poll(query::Function, timeout::Real=Inf64, interval::Real=1/20)
start = time()
while time() < start + timeout
if query()
return true
end
sleep(interval)
end
return false
end







abstract type AbstractWorker end

"""
Expand All @@ -44,6 +95,8 @@ end
function unwrap_worker_result(worker::AbstractWorker, result::WorkerResult)
if result.msg_type == MsgType.special_serialization_failure
throw(ErrorException("Error deserializing data from $(summary(worker)):\n\n$(sprint(Base.showerror, result.value))"))
elseif result.msg_type == MsgType.special_worker_terminated
throw(TerminatedWorkerException())
elseif result.msg_type == MsgType.from_worker_call_failure
throw(RemoteException(worker, result.value))
else
Expand Down Expand Up @@ -90,6 +143,7 @@ Malt.Worker(0x0000, Process(`…`, ProcessRunning))
mutable struct Worker <: AbstractWorker
port::UInt16
proc::Base.Process
proc_pid::Int32

current_socket::Sockets.TCPSocket
# socket_lock::ReentrantLock
Expand All @@ -100,7 +154,11 @@ mutable struct Worker <: AbstractWorker
function Worker(; env=String[], exeflags=[])
# Spawn process
cmd = _get_worker_cmd(; env, exeflags)
proc = open(cmd, "w+")
proc = open(Cmd(
cmd;
detach=Sys.iswindows(),
windows_hide=true,
), "w+")

# Keep internal list
__iNtErNaL_get_running_procs()
Expand All @@ -117,7 +175,14 @@ mutable struct Worker <: AbstractWorker

# There's no reason to keep the worker process alive after the manager loses its handle.
w = finalizer(w -> @async(stop(w)),
new(port, proc, socket, MsgID(0), Dict{MsgID,Channel{WorkerResult}}())
new(
port,
proc,
getpid(proc),
socket,
MsgID(0),
Dict{MsgID,Channel{WorkerResult}}(),
)
)
atexit(() -> stop(w))

Expand All @@ -127,17 +192,32 @@ mutable struct Worker <: AbstractWorker
end
end

Base.summary(io::IO, w::Worker) = write(io, "Malt.Worker on port $(w.port)")
Base.summary(io::IO, w::Worker) = write(io, "Malt.Worker on port $(w.port) with PID $(w.proc_pid)")


function _receive_loop(worker::Worker)
io = worker.current_socket

exit_handler_task = @async for _i in Iterators.countfrom(1)
try
if !isrunning(worker)
for c in values(worker.expected_replies)
isready(c) || put!(c, WorkerResult(MsgType.special_worker_terminated, nothing))
end
break
end
sleep(1)
catch e
@error "asdfdfs" exception=(e,catch_backtrace())
end
end

# Here we use:
# `for _i in Iterators.countfrom(1)`
# instead of
# `while true`
# as a workaround for https://github.com/JuliaLang/julia/issues/37154
@async for _i in Iterators.countfrom(1)
listen_task = @async for _i in Iterators.countfrom(1)
try
if !isopen(io)
@debug("HOST: io closed.")
Expand Down Expand Up @@ -581,18 +661,58 @@ Send an interrupt signal to the worker process. This will interrupt the
latest request (`remote_call*` or `remote_eval*`) that was sent to the worker.
"""
function interrupt(w::Worker)
if Sys.iswindows()
# TODO: not yet implemented
@warn "Malt.interrupt is not yet supported on Windows"
# _assert_is_running(w)
# _send_msg(w, MsgType.from_host_fake_interrupt, (), false)
nothing
if !isrunning(w)
@warn "Tried to interrupt a worker that has already stopped running." summary(w)
else
Base.kill(w.proc, Base.SIGINT)
if Sys.iswindows()
ccall((:GenerateConsoleCtrlEvent,"Kernel32"), Bool, (UInt32, UInt32), UInt32(1), UInt32(getpid(w.proc)))
fonsp marked this conversation as resolved.
Show resolved Hide resolved
else
Base.kill(w.proc, Base.SIGINT)
end
end
nothing
end
function interrupt(w::InProcessWorker)
schedule(w.latest_request_task, InterruptException(); error=true)
isdone(w.latest_request_task) || schedule(w.latest_request_task, InterruptException(); error=true)
nothing
end


function interrupt_auto(w::AbstractWorker; verbose::Bool=true)
t = remote_call(&, w, true, true)

done() = !isrunning(w) || istaskdone(t)

try
verbose && @info "Sending interrupt to process $(w)"
interrupt(w)

if poll(() -> done(), 5.0, 5/100)
verbose && println("Cell interrupted!")
return true
end

verbose && println("Still running... starting sequence")
while !done()
for _ in 1:5
verbose && print(" 🔥 ")
interrupt(w)
sleep(0.18)
if done()
break
end
end
sleep(1.5)
end
verbose && println()
verbose && println("Cell interrupted!")
true
catch e
# if !(e isa KeyError)
@warn "Interrupt failed for unknown reason" exception=(e,catch_backtrace())
# end
false
end
end


Expand Down
1 change: 1 addition & 0 deletions src/shared.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const MsgType = (
from_worker_call_failure = UInt8(81),
###
special_serialization_failure = UInt8(100),
special_worker_terminated = UInt8(101),
)

const MsgID = UInt64
Expand Down
101 changes: 101 additions & 0 deletions test/interrupt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# @testset "Interrupt: $W" for W in (m.DistributedStdlibWorker, m.InProcessWorker, m.Worker)
@testset "Interrupt: $W" for W in (m.Worker,)


w = W()

# m.interrupt(w)
@test m.isrunning(w)
@test m.remote_call_fetch(&, w, true, true)


ex1 = quote
local x = 0.0
for i in 1:20_000_000
x += sqrt(abs(sin(cos(tan(x)))))^(1/i)
end
x
end

exs = [
ex1,
quote
sleep(3)
end,
ex1, # second time because interrupts should be reliable
]

@testset "single interrupt $ex" for ex in exs

f() = m.remote_eval(w, ex)

t1 = @elapsed wait(f())
t2 = @elapsed wait(f())
@info "first run" t1 t2

t3 = @elapsed begin
t = f()
@test !istaskdone(t)
m.interrupt(w)
@test try
wait(t)
nothing
catch e
e
end isa TaskFailedException
# @test t.exception isa InterruptException
end

t4 = @elapsed begin
t = f()
@test !istaskdone(t)
m.interrupt(w)
@test try
wait(t)
nothing
catch e
e
end isa TaskFailedException
# @test t.exception isa InterruptException
end

@info "test run" t1 t2 t3 t4
@test t4 < min(t1,t2) * 0.8

# still running and responsive
@test m.isrunning(w)
@test m.remote_call_fetch(&, w, true, true)

end

@testset "hard interrupt" begin
t = m.remote_eval(w, :(while true end))

@test !istaskdone(t)
@test m.isrunning(w)
m.interrupt_auto(w)
@info "xx" istaskdone(t) m.isrunning(w)

@test try
wait(t)
nothing
catch e
e
end isa TaskFailedException

# hello
@test true

if Sys.iswindows()
@info "Interrupt done" m.isrunning(w)
else
# still running and responsive
@test m.isrunning(w)
@test m.remote_call_fetch(&, w, true, true)
end
end


m.stop(w)
@test !m.isrunning(w)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using Test

v() = @assert isempty(m.__iNtErNaL_get_running_procs())

v()
include("interrupt.jl")
v()
include("basic.jl")
v()
Expand Down
Loading