Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: handle cancelling requests #194

Merged
merged 4 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.21
0.9.22
15 changes: 8 additions & 7 deletions lib/supavisor.ex
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ defmodule Supavisor do

@registry Supavisor.Registry.Tenants

@spec start(id, secrets) :: {:ok, pid} | {:error, any}
def start(id, secrets) do
@spec start(id, secrets, String.t() | nil) :: {:ok, pid} | {:error, any}
def start(id, secrets, db_name) do
case get_global_sup(id) do
nil ->
start_local_pool(id, secrets)
start_local_pool(id, secrets, db_name)

pid ->
{:ok, pid}
Expand Down Expand Up @@ -156,8 +156,8 @@ defmodule Supavisor do

## Internal functions

@spec start_local_pool(id, secrets) :: {:ok, pid} | {:error, any}
defp start_local_pool({tenant, user, mode} = id, {method, secrets}) do
@spec start_local_pool(id, secrets, String.t() | nil) :: {:ok, pid} | {:error, any}
defp start_local_pool({tenant, user, mode} = id, {method, secrets}, db_name) do
Logger.debug("Starting pool for #{inspect(id)}")

case Tenants.get_pool_config(tenant, secrets.().alias) do
Expand Down Expand Up @@ -194,7 +194,7 @@ defmodule Supavisor do
host: String.to_charlist(db_host),
port: db_port,
user: db_user,
database: db_database,
database: if(db_name != nil, do: db_name, else: db_database),
password: fn -> db_pass end,
application_name: "supavisor",
ip_version: H.ip_version(ip_ver, db_host),
Expand Down Expand Up @@ -235,7 +235,8 @@ defmodule Supavisor do
end
end

@spec set_parameter_status(id, [{binary, binary}]) :: :ok | {:error, :not_found}
@spec set_parameter_status(id, [{binary, binary}]) ::
:ok | {:error, :not_found}
def set_parameter_status(id, ps) do
case get_local_manager(id) do
nil -> {:error, :not_found}
Expand Down
1 change: 1 addition & 0 deletions lib/supavisor/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ defmodule Supavisor.Application do
PromEx,
{Registry, keys: :unique, name: Supavisor.Registry.Tenants},
{Registry, keys: :unique, name: Supavisor.Registry.ManagerTables},
{Registry, keys: :unique, name: Supavisor.Registry.PoolPids},
{Registry, keys: :duplicate, name: Supavisor.Registry.TenantSups},
{Registry, keys: :duplicate, name: Supavisor.Registry.TenantClients},
{Cluster.Supervisor, [topologies, [name: Supavisor.ClusterSupervisor]]},
Expand Down
47 changes: 44 additions & 3 deletions lib/supavisor/client_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ defmodule Supavisor.ClientHandler do
proxy_type: nil,
mode: opts.mode,
stats: %{},
idle_timeout: 0
idle_timeout: 0,
db_name: nil
}

:gen_statem.enter_loop(__MODULE__, [hibernate_after: 5_000], :exchange, data)
Expand All @@ -70,6 +71,31 @@ defmodule Supavisor.ClientHandler do
{:stop, :normal, data}
end

# cancel request
def handle_event(:info, {_, _, <<16::32, 1234::16, 5678::16, pid::32, key::32>>}, _, data) do
Logger.debug("Got cancel query for #{inspect({pid, key})}")
:ok = HH.send_cancel_query(pid, key)
{:stop, :normal, data}
end

# send cancel request to db
def handle_event(:info, :cancel_query, :busy, data) do
key = {data.tenant, data.db_pid}
Logger.debug("Cancel query for #{inspect(key)}")

db_pid = data.db_pid

case db_pid_meta(key) do
[{^db_pid, meta}] ->
:ok = HH.cancel_query(meta.host, meta.port, meta.ip_ver, meta.pid, meta.key)

error ->
Logger.error("Received cancel but no proc was found #{inspect(key)} #{inspect(error)}")
end

:keep_state_and_data
end

def handle_event(:info, {:tcp, _, <<_::64>>}, :exchange, %{sock: sock} = data) do
Logger.debug("Client is trying to connect with SSL")

Expand Down Expand Up @@ -109,6 +135,7 @@ defmodule Supavisor.ClientHandler do
Logger.debug("Client startup message: #{inspect(hello)}")
{user, external_id} = HH.parse_user_info(hello.payload)
Logger.metadata(project: external_id, user: user, mode: data.mode)
data = %{data | db_name: hello.payload["database"]}
{:keep_state, data, {:next_event, :internal, {:hello, {user, external_id}}}}

{:error, error} ->
Expand Down Expand Up @@ -192,7 +219,7 @@ defmodule Supavisor.ClientHandler do
def handle_event(:internal, :subscribe, _, data) do
Logger.debug("Subscribe to tenant #{inspect(data.id)}")

with {:ok, sup} <- Supavisor.start(data.id, data.auth_secrets),
with {:ok, sup} <- Supavisor.start(data.id, data.auth_secrets, data.db_name),
{:ok, opts} <- Supavisor.subscribe(sup, data.id) do
Process.monitor(opts.workers.manager)
data = Map.merge(data, opts.workers)
Expand Down Expand Up @@ -221,7 +248,10 @@ defmodule Supavisor.ClientHandler do
end

def handle_event(:internal, {:greetings, ps}, _, %{sock: sock} = data) do
:ok = HH.sock_send(sock, Server.greetings(ps))
{header, <<pid::32, key::32>> = payload} = Server.backend_key_data()
msg = [ps, [header, payload], Server.ready_for_query()]
:ok = HH.listen_cancel_query(pid, key)
:ok = HH.sock_send(sock, msg)

if data.idle_timeout > 0 do
{:next_state, :idle, data, idle_check(data.idle_timeout)}
Expand Down Expand Up @@ -632,4 +662,15 @@ defmodule Supavisor.ClientHandler do
defp idle_check(timeout) do
{:timeout, timeout, :idle_terminate}
end

defp db_pid_meta({_, pid} = key) do
rkey = Supavisor.Registry.PoolPids
fnode = node(pid)

if fnode == node() do
Registry.lookup(rkey, key)
else
:erpc.call(fnode, Registry, :lookup, [rkey, key], 15_000)
end
end
end
7 changes: 7 additions & 0 deletions lib/supavisor/db_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@

case try_ssl_handshake({:gen_tcp, sock}, auth) do
{:ok, sock} ->
case send_startup(sock, auth) do

Check warning on line 75 in lib/supavisor/db_handler.ex

View workflow job for this annotation

GitHub Actions / Formatting Checks

Function body is nested too deep (max depth is 2, was 3).
:ok ->
:ok = activate(sock)
{:next_state, :authentication, %{data | sock: sock}}
Expand Down Expand Up @@ -110,6 +110,13 @@
%{tag: :ready_for_query, payload: db_state}, {ps, _} ->
{ps, db_state}

%{tag: :backend_key_data, payload: payload}, acc ->
key = {data.tenant, self()}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You may not need to have self() as part of the key. Instead you can just use data.tenant and use Registry.values to look up the values for a given key-pid pair.

conn = %{host: data.auth.host, port: data.auth.port, ip_ver: data.auth.ip_version}
Registry.register(Supavisor.Registry.PoolPids, key, Map.merge(payload, conn))
Logger.debug("Backend #{inspect(key)} data: #{inspect(payload)}")
acc

%{payload: {:authentication_sasl_password, methods_b}}, {ps, _} ->
nonce =
case Server.decode_string(methods_b) do
Expand Down
32 changes: 32 additions & 0 deletions lib/supavisor/handler_helpers.ex
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
defmodule Supavisor.HandlerHelpers do
@moduledoc false

alias Phoenix.PubSub
alias Supavisor, as: S
alias Supavisor.Protocol.Server

Expand All @@ -9,6 +10,13 @@ defmodule Supavisor.HandlerHelpers do
mod.send(sock, data)
end

@spec sock_close(nil | S.sock()) :: :ok | {:error, term()}
def sock_close(nil), do: :ok

def sock_close({mod, sock}) do
mod.close(sock)
end

@spec setopts(S.sock(), term()) :: :ok | {:error, term()}
def setopts({mod, sock}, opts) do
mod = if mod == :gen_tcp, do: :inet, else: mod
Expand Down Expand Up @@ -87,4 +95,28 @@ defmodule Supavisor.HandlerHelpers do
{name, external_id}
end
end

@spec send_cancel_query(non_neg_integer, non_neg_integer) :: :ok | {:errr, term}
def send_cancel_query(pid, key) do
PubSub.broadcast(
Supavisor.PubSub,
"cancel_req:#{pid}_#{key}",
:cancel_query
)
end

@spec listen_cancel_query(non_neg_integer, non_neg_integer) :: :ok | {:errr, term}
def listen_cancel_query(pid, key) do
PubSub.subscribe(Supavisor.PubSub, "cancel_req:#{pid}_#{key}")
end

@spec cancel_query(keyword, non_neg_integer, atom, non_neg_integer, non_neg_integer) :: :ok
def cancel_query(host, port, ip_version, pid, key) do
msg = Server.cancel_message(pid, key)
opts = [:binary, {:packet, :raw}, {:active, true}, ip_version]
{:ok, sock} = :gen_tcp.connect(host, port, opts)
sock = {:gen_tcp, sock}
:ok = sock_send(sock, msg)
:ok = sock_close(sock)
end
end
70 changes: 57 additions & 13 deletions lib/supavisor/native_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ defmodule Supavisor.NativeHandler do
trans: trans,
acc: nil,
status: :startup,
ssl: false
ssl: false,
db_auth: nil,
backend_key: nil
}

:gen_server.enter_loop(__MODULE__, [hibernate_after: 5_000], state)
Expand All @@ -46,6 +48,16 @@ defmodule Supavisor.NativeHandler do
{:stop, :normal, state}
end

def handle_info(
{:tcp, sock, <<16::32, 1234::16, 5678::16, pid::32, key::32>>},
%{status: :startup, client_sock: {_, sock} = client_sock} = state
) do
Logger.debug("Got cancel query for #{inspect({pid, key})}")
:ok = HH.send_cancel_query(pid, key)
:ok = HH.sock_close(client_sock)
{:stop, :normal, state}
end

# ssl request from client
def handle_info(
{:tcp, sock, <<_::64>>} = _msg,
Expand Down Expand Up @@ -85,6 +97,28 @@ defmodule Supavisor.NativeHandler do
end

# send packets to client from db
def handle_info(
{_, sock, bin},
%{db_sock: {_, sock}, backend_key: nil} = state
) do
state =
bin
|> Server.decode()
|> Enum.filter(fn e -> Map.get(e, :tag) == :backend_key_data end)
|> case do
[%{payload: %{key: key, pid: pid} = k}] ->
Logger.debug("Backend key: #{inspect(k)}")
:ok = HH.listen_cancel_query(pid, key)
%{state | backend_key: k}

_ ->
state
end

:ok = HH.sock_send(state.client_sock, bin)
{:noreply, state}
end

def handle_info({_, sock, bin}, %{db_sock: {_, sock}} = state) do
:ok = HH.sock_send(state.client_sock, bin)
{:noreply, state}
Expand Down Expand Up @@ -115,9 +149,13 @@ defmodule Supavisor.NativeHandler do
end
|> Server.encode_startup_packet()

case connect_local(host, port, payload, state.ssl) do
ip_ver = H.detect_ip_version(host)
host = String.to_charlist(host)

case connect_local(host, port, payload, ip_ver, state.ssl) do
{:ok, db_sock} ->
{:noreply, %{state | db_sock: db_sock}}
auth = %{host: host, port: port, ip_ver: ip_ver}
{:noreply, %{state | db_sock: db_sock, db_auth: auth}}

{:error, reason} ->
Logger.error("Error connecting to tenant db: #{inspect(reason)}")
Expand All @@ -141,35 +179,41 @@ defmodule Supavisor.NativeHandler do
{:noreply, state}
end

def handle_info({:tcp_closed, _} = msg, state) do
Logger.debug("Terminating #{inspect(msg, pretty: true)}")
def handle_info({closed, _} = msg, state) when closed in [:tcp_closed, :ssl_closed] do
Logger.debug("Closed socket #{inspect(msg, pretty: true)}")
{:stop, :normal, state}
end

def handle_info({:ssl_closed, _} = msg, state) do
Logger.debug("Terminating #{inspect(msg, pretty: true)}")
{:stop, :normal, state}
def handle_info(:cancel_query, %{backend_key: key, db_auth: auth} = state) do
Logger.debug("Cancel query for #{inspect(key)}")
:ok = HH.cancel_query(auth.host, auth.port, auth.ip_ver, key.pid, key.key)
{:noreply, state}
end

def handle_info(msg, state) do
Logger.error("Undefined message #{inspect(msg, pretty: true)}")
{:noreply, state}
end

@impl true
def terminate(_reason, state) do
Logger.debug("Terminate #{inspect(self())}")
:ok = HH.sock_close(state.db_sock)
:ok = HH.sock_close(state.client_sock)
end

### Internal functions

@spec connect_local(String.t(), non_neg_integer, binary, boolean) ::
@spec connect_local(keyword, non_neg_integer, binary, atom, boolean) ::
{:ok, S.sock()} | {:error, term()}
defp connect_local(host, port, payload, ssl?) do
defp connect_local(host, port, payload, ip_ver, ssl?) do
sock_opts = [
:binary,
{:packet, :raw},
{:active, false},
H.detect_ip_version(host)
ip_ver
]

host = String.to_charlist(host)

with {:ok, sock} <- :gen_tcp.connect(host, port, sock_opts),
{:ok, sock} <- HH.try_ssl_handshake({:gen_tcp, sock}, ssl?),
:ok <- HH.sock_send(sock, payload) do
Expand Down
25 changes: 13 additions & 12 deletions lib/supavisor/protocol/server.ex
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ defmodule Supavisor.Protocol.Server do
@ready_for_query <<?Z, 5::32, ?I>>
@ssl_request <<8::32, 1234::16, 5679::16>>
@scram_request <<?R, 23::32, 10::32, "SCRAM-SHA-256", 0, 0>>
@msg_cancel_header <<16::32, 1234::16, 5678::16>>

defmodule Pkt do
@moduledoc "Representing a packet structure with tag, length, and payload fields."
Expand Down Expand Up @@ -133,8 +134,8 @@ defmodule Supavisor.Protocol.Server do
end
end

def decode_payload(:backend_key_data, <<proc_id::integer-32, secret::integer-32>>) do
%{procid: proc_id, secret: secret}
def decode_payload(:backend_key_data, <<pid::integer-32, key::integer-32>>) do
%{pid: pid, key: key}
end

def decode_payload(:ready_for_query, payload) do
Expand Down Expand Up @@ -360,18 +361,13 @@ defmodule Supavisor.Protocol.Server do
[<<?S, len::integer-32>>, payload]
end

@spec backend_key_data() :: iodata()
@spec backend_key_data() :: {iodata(), binary}
def backend_key_data() do
procid = System.unique_integer([:positive, :monotonic])
secret = Enum.random(0..9_999_999_999)
payload = <<procid::integer-32, secret::integer-32>>
pid = System.unique_integer([:positive, :monotonic])
key = :crypto.strong_rand_bytes(4)
payload = <<pid::integer-32, key::binary>>
len = IO.iodata_length(payload) + 4
[<<?K, len::32>>, payload]
end

@spec greetings(iodata()) :: iodata()
def greetings(ps) do
[ps, backend_key_data(), @ready_for_query]
{<<?K, len::32>>, payload}
end

@spec ready_for_query() :: binary()
Expand Down Expand Up @@ -444,4 +440,9 @@ defmodule Supavisor.Protocol.Server do

<<byte_size(bin) + 9::32, 0, 3, 0, 0, bin::binary, 0>>
end

@spec cancel_message(non_neg_integer, non_neg_integer) :: iodata
def cancel_message(pid, key) do
[@msg_cancel_header, <<pid::32, key::32>>]
end
end
Loading
Loading