Skip to content

Commit

Permalink
feat: handle cancelling requests (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
abc3 authored Oct 24, 2023
1 parent 5a0571e commit c6e7335
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 63 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.21
0.9.22
15 changes: 8 additions & 7 deletions lib/supavisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ defmodule Supavisor do

@registry Supavisor.Registry.Tenants

@spec start(id, secrets) :: {:ok, pid} | {:error, any}
def start(id, secrets) do
@spec start(id, secrets, String.t() | nil) :: {:ok, pid} | {:error, any}
def start(id, secrets, db_name) do
case get_global_sup(id) do
nil ->
start_local_pool(id, secrets)
start_local_pool(id, secrets, db_name)

pid ->
{:ok, pid}
Expand Down Expand Up @@ -156,8 +156,8 @@ defmodule Supavisor do

## Internal functions

@spec start_local_pool(id, secrets) :: {:ok, pid} | {:error, any}
defp start_local_pool({tenant, user, mode} = id, {method, secrets}) do
@spec start_local_pool(id, secrets, String.t() | nil) :: {:ok, pid} | {:error, any}
defp start_local_pool({tenant, user, mode} = id, {method, secrets}, db_name) do
Logger.debug("Starting pool for #{inspect(id)}")

case Tenants.get_pool_config(tenant, secrets.().alias) do
Expand Down Expand Up @@ -194,7 +194,7 @@ defmodule Supavisor do
host: String.to_charlist(db_host),
port: db_port,
user: db_user,
database: db_database,
database: if(db_name != nil, do: db_name, else: db_database),
password: fn -> db_pass end,
application_name: "supavisor",
ip_version: H.ip_version(ip_ver, db_host),
Expand Down Expand Up @@ -235,7 +235,8 @@ defmodule Supavisor do
end
end

@spec set_parameter_status(id, [{binary, binary}]) :: :ok | {:error, :not_found}
@spec set_parameter_status(id, [{binary, binary}]) ::
:ok | {:error, :not_found}
def set_parameter_status(id, ps) do
case get_local_manager(id) do
nil -> {:error, :not_found}
Expand Down
1 change: 1 addition & 0 deletions lib/supavisor/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ defmodule Supavisor.Application do
PromEx,
{Registry, keys: :unique, name: Supavisor.Registry.Tenants},
{Registry, keys: :unique, name: Supavisor.Registry.ManagerTables},
{Registry, keys: :unique, name: Supavisor.Registry.PoolPids},
{Registry, keys: :duplicate, name: Supavisor.Registry.TenantSups},
{Registry, keys: :duplicate, name: Supavisor.Registry.TenantClients},
{Cluster.Supervisor, [topologies, [name: Supavisor.ClusterSupervisor]]},
Expand Down
47 changes: 44 additions & 3 deletions lib/supavisor/client_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ defmodule Supavisor.ClientHandler do
proxy_type: nil,
mode: opts.mode,
stats: %{},
idle_timeout: 0
idle_timeout: 0,
db_name: nil
}

:gen_statem.enter_loop(__MODULE__, [hibernate_after: 5_000], :exchange, data)
Expand All @@ -70,6 +71,31 @@ defmodule Supavisor.ClientHandler do
{:stop, :normal, data}
end

# cancel request
def handle_event(:info, {_, _, <<16::32, 1234::16, 5678::16, pid::32, key::32>>}, _, data) do
Logger.debug("Got cancel query for #{inspect({pid, key})}")
:ok = HH.send_cancel_query(pid, key)
{:stop, :normal, data}
end

# send cancel request to db
def handle_event(:info, :cancel_query, :busy, data) do
key = {data.tenant, data.db_pid}
Logger.debug("Cancel query for #{inspect(key)}")

db_pid = data.db_pid

case db_pid_meta(key) do
[{^db_pid, meta}] ->
:ok = HH.cancel_query(meta.host, meta.port, meta.ip_ver, meta.pid, meta.key)

error ->
Logger.error("Received cancel but no proc was found #{inspect(key)} #{inspect(error)}")
end

:keep_state_and_data
end

def handle_event(:info, {:tcp, _, <<_::64>>}, :exchange, %{sock: sock} = data) do
Logger.debug("Client is trying to connect with SSL")

Expand Down Expand Up @@ -109,6 +135,7 @@ defmodule Supavisor.ClientHandler do
Logger.debug("Client startup message: #{inspect(hello)}")
{user, external_id} = HH.parse_user_info(hello.payload)
Logger.metadata(project: external_id, user: user, mode: data.mode)
data = %{data | db_name: hello.payload["database"]}
{:keep_state, data, {:next_event, :internal, {:hello, {user, external_id}}}}

{:error, error} ->
Expand Down Expand Up @@ -192,7 +219,7 @@ defmodule Supavisor.ClientHandler do
def handle_event(:internal, :subscribe, _, data) do
Logger.debug("Subscribe to tenant #{inspect(data.id)}")

with {:ok, sup} <- Supavisor.start(data.id, data.auth_secrets),
with {:ok, sup} <- Supavisor.start(data.id, data.auth_secrets, data.db_name),
{:ok, opts} <- Supavisor.subscribe(sup, data.id) do
Process.monitor(opts.workers.manager)
data = Map.merge(data, opts.workers)
Expand Down Expand Up @@ -221,7 +248,10 @@ defmodule Supavisor.ClientHandler do
end

def handle_event(:internal, {:greetings, ps}, _, %{sock: sock} = data) do
:ok = HH.sock_send(sock, Server.greetings(ps))
{header, <<pid::32, key::32>> = payload} = Server.backend_key_data()
msg = [ps, [header, payload], Server.ready_for_query()]
:ok = HH.listen_cancel_query(pid, key)
:ok = HH.sock_send(sock, msg)

if data.idle_timeout > 0 do
{:next_state, :idle, data, idle_check(data.idle_timeout)}
Expand Down Expand Up @@ -632,4 +662,15 @@ defmodule Supavisor.ClientHandler do
defp idle_check(timeout) do
{:timeout, timeout, :idle_terminate}
end

defp db_pid_meta({_, pid} = key) do
rkey = Supavisor.Registry.PoolPids
fnode = node(pid)

if fnode == node() do
Registry.lookup(rkey, key)
else
:erpc.call(fnode, Registry, :lookup, [rkey, key], 15_000)
end
end
end
7 changes: 7 additions & 0 deletions lib/supavisor/db_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ defmodule Supavisor.DbHandler do
%{tag: :ready_for_query, payload: db_state}, {ps, _} ->
{ps, db_state}

%{tag: :backend_key_data, payload: payload}, acc ->
key = {data.tenant, self()}
conn = %{host: data.auth.host, port: data.auth.port, ip_ver: data.auth.ip_version}
Registry.register(Supavisor.Registry.PoolPids, key, Map.merge(payload, conn))
Logger.debug("Backend #{inspect(key)} data: #{inspect(payload)}")
acc

%{payload: {:authentication_sasl_password, methods_b}}, {ps, _} ->
nonce =
case Server.decode_string(methods_b) do
Expand Down
32 changes: 32 additions & 0 deletions lib/supavisor/handler_helpers.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
defmodule Supavisor.HandlerHelpers do
@moduledoc false

alias Phoenix.PubSub
alias Supavisor, as: S
alias Supavisor.Protocol.Server

Expand All @@ -9,6 +10,13 @@ defmodule Supavisor.HandlerHelpers do
mod.send(sock, data)
end

@spec sock_close(nil | S.sock()) :: :ok | {:error, term()}
def sock_close(nil), do: :ok

def sock_close({mod, sock}) do
mod.close(sock)
end

@spec setopts(S.sock(), term()) :: :ok | {:error, term()}
def setopts({mod, sock}, opts) do
mod = if mod == :gen_tcp, do: :inet, else: mod
Expand Down Expand Up @@ -87,4 +95,28 @@ defmodule Supavisor.HandlerHelpers do
{name, external_id}
end
end

@spec send_cancel_query(non_neg_integer, non_neg_integer) :: :ok | {:errr, term}
def send_cancel_query(pid, key) do
PubSub.broadcast(
Supavisor.PubSub,
"cancel_req:#{pid}_#{key}",
:cancel_query
)
end

@spec listen_cancel_query(non_neg_integer, non_neg_integer) :: :ok | {:errr, term}
def listen_cancel_query(pid, key) do
PubSub.subscribe(Supavisor.PubSub, "cancel_req:#{pid}_#{key}")
end

@spec cancel_query(keyword, non_neg_integer, atom, non_neg_integer, non_neg_integer) :: :ok
def cancel_query(host, port, ip_version, pid, key) do
msg = Server.cancel_message(pid, key)
opts = [:binary, {:packet, :raw}, {:active, true}, ip_version]
{:ok, sock} = :gen_tcp.connect(host, port, opts)
sock = {:gen_tcp, sock}
:ok = sock_send(sock, msg)
:ok = sock_close(sock)
end
end
70 changes: 57 additions & 13 deletions lib/supavisor/native_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ defmodule Supavisor.NativeHandler do
trans: trans,
acc: nil,
status: :startup,
ssl: false
ssl: false,
db_auth: nil,
backend_key: nil
}

:gen_server.enter_loop(__MODULE__, [hibernate_after: 5_000], state)
Expand All @@ -46,6 +48,16 @@ defmodule Supavisor.NativeHandler do
{:stop, :normal, state}
end

def handle_info(
{:tcp, sock, <<16::32, 1234::16, 5678::16, pid::32, key::32>>},
%{status: :startup, client_sock: {_, sock} = client_sock} = state
) do
Logger.debug("Got cancel query for #{inspect({pid, key})}")
:ok = HH.send_cancel_query(pid, key)
:ok = HH.sock_close(client_sock)
{:stop, :normal, state}
end

# ssl request from client
def handle_info(
{:tcp, sock, <<_::64>>} = _msg,
Expand Down Expand Up @@ -85,6 +97,28 @@ defmodule Supavisor.NativeHandler do
end

# send packets to client from db
def handle_info(
{_, sock, bin},
%{db_sock: {_, sock}, backend_key: nil} = state
) do
state =
bin
|> Server.decode()
|> Enum.filter(fn e -> Map.get(e, :tag) == :backend_key_data end)
|> case do
[%{payload: %{key: key, pid: pid} = k}] ->
Logger.debug("Backend key: #{inspect(k)}")
:ok = HH.listen_cancel_query(pid, key)
%{state | backend_key: k}

_ ->
state
end

:ok = HH.sock_send(state.client_sock, bin)
{:noreply, state}
end

def handle_info({_, sock, bin}, %{db_sock: {_, sock}} = state) do
:ok = HH.sock_send(state.client_sock, bin)
{:noreply, state}
Expand Down Expand Up @@ -115,9 +149,13 @@ defmodule Supavisor.NativeHandler do
end
|> Server.encode_startup_packet()

case connect_local(host, port, payload, state.ssl) do
ip_ver = H.detect_ip_version(host)
host = String.to_charlist(host)

case connect_local(host, port, payload, ip_ver, state.ssl) do
{:ok, db_sock} ->
{:noreply, %{state | db_sock: db_sock}}
auth = %{host: host, port: port, ip_ver: ip_ver}
{:noreply, %{state | db_sock: db_sock, db_auth: auth}}

{:error, reason} ->
Logger.error("Error connecting to tenant db: #{inspect(reason)}")
Expand All @@ -141,35 +179,41 @@ defmodule Supavisor.NativeHandler do
{:noreply, state}
end

def handle_info({:tcp_closed, _} = msg, state) do
Logger.debug("Terminating #{inspect(msg, pretty: true)}")
def handle_info({closed, _} = msg, state) when closed in [:tcp_closed, :ssl_closed] do
Logger.debug("Closed socket #{inspect(msg, pretty: true)}")
{:stop, :normal, state}
end

def handle_info({:ssl_closed, _} = msg, state) do
Logger.debug("Terminating #{inspect(msg, pretty: true)}")
{:stop, :normal, state}
def handle_info(:cancel_query, %{backend_key: key, db_auth: auth} = state) do
Logger.debug("Cancel query for #{inspect(key)}")
:ok = HH.cancel_query(auth.host, auth.port, auth.ip_ver, key.pid, key.key)
{:noreply, state}
end

def handle_info(msg, state) do
Logger.error("Undefined message #{inspect(msg, pretty: true)}")
{:noreply, state}
end

@impl true
def terminate(_reason, state) do
Logger.debug("Terminate #{inspect(self())}")
:ok = HH.sock_close(state.db_sock)
:ok = HH.sock_close(state.client_sock)
end

### Internal functions

@spec connect_local(String.t(), non_neg_integer, binary, boolean) ::
@spec connect_local(keyword, non_neg_integer, binary, atom, boolean) ::
{:ok, S.sock()} | {:error, term()}
defp connect_local(host, port, payload, ssl?) do
defp connect_local(host, port, payload, ip_ver, ssl?) do
sock_opts = [
:binary,
{:packet, :raw},
{:active, false},
H.detect_ip_version(host)
ip_ver
]

host = String.to_charlist(host)

with {:ok, sock} <- :gen_tcp.connect(host, port, sock_opts),
{:ok, sock} <- HH.try_ssl_handshake({:gen_tcp, sock}, ssl?),
:ok <- HH.sock_send(sock, payload) do
Expand Down
25 changes: 13 additions & 12 deletions lib/supavisor/protocol/server.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ defmodule Supavisor.Protocol.Server do
@ready_for_query <<?Z, 5::32, ?I>>
@ssl_request <<8::32, 1234::16, 5679::16>>
@scram_request <<?R, 23::32, 10::32, "SCRAM-SHA-256", 0, 0>>
@msg_cancel_header <<16::32, 1234::16, 5678::16>>

defmodule Pkt do
@moduledoc "Representing a packet structure with tag, length, and payload fields."
Expand Down Expand Up @@ -133,8 +134,8 @@ defmodule Supavisor.Protocol.Server do
end
end

def decode_payload(:backend_key_data, <<proc_id::integer-32, secret::integer-32>>) do
%{procid: proc_id, secret: secret}
def decode_payload(:backend_key_data, <<pid::integer-32, key::integer-32>>) do
%{pid: pid, key: key}
end

def decode_payload(:ready_for_query, payload) do
Expand Down Expand Up @@ -360,18 +361,13 @@ defmodule Supavisor.Protocol.Server do
[<<?S, len::integer-32>>, payload]
end

@spec backend_key_data() :: iodata()
@spec backend_key_data() :: {iodata(), binary}
def backend_key_data() do
procid = System.unique_integer([:positive, :monotonic])
secret = Enum.random(0..9_999_999_999)
payload = <<procid::integer-32, secret::integer-32>>
pid = System.unique_integer([:positive, :monotonic])
key = :crypto.strong_rand_bytes(4)
payload = <<pid::integer-32, key::binary>>
len = IO.iodata_length(payload) + 4
[<<?K, len::32>>, payload]
end

@spec greetings(iodata()) :: iodata()
def greetings(ps) do
[ps, backend_key_data(), @ready_for_query]
{<<?K, len::32>>, payload}
end

@spec ready_for_query() :: binary()
Expand Down Expand Up @@ -444,4 +440,9 @@ defmodule Supavisor.Protocol.Server do

<<byte_size(bin) + 9::32, 0, 3, 0, 0, bin::binary, 0>>
end

@spec cancel_message(non_neg_integer, non_neg_integer) :: iodata
def cancel_message(pid, key) do
[@msg_cancel_header, <<pid::32, key::32>>]
end
end
Loading

0 comments on commit c6e7335

Please sign in to comment.