From c6e7335eed1cf8fcbad6b40e31505de6f8051131 Mon Sep 17 00:00:00 2001 From: Stas Date: Tue, 24 Oct 2023 10:14:15 +0200 Subject: [PATCH] feat: handle cancelling requests (#194) --- VERSION | 2 +- lib/supavisor.ex | 15 ++++--- lib/supavisor/application.ex | 1 + lib/supavisor/client_handler.ex | 47 +++++++++++++++++-- lib/supavisor/db_handler.ex | 7 +++ lib/supavisor/handler_helpers.ex | 32 +++++++++++++ lib/supavisor/native_handler.ex | 70 +++++++++++++++++++++++------ lib/supavisor/protocol/server.ex | 25 ++++++----- test/supavisor/protocol_test.exs | 36 ++++++--------- test/supavisor/syn_handler_test.exs | 4 +- test/support/fixtures/helpers.ex | 4 +- 11 files changed, 180 insertions(+), 63 deletions(-) diff --git a/VERSION b/VERSION index b2160230..ea3f0d7a 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.9.21 +0.9.22 diff --git a/lib/supavisor.ex b/lib/supavisor.ex index 2b5b822a..624bff7e 100644 --- a/lib/supavisor.ex +++ b/lib/supavisor.ex @@ -15,11 +15,11 @@ defmodule Supavisor do @registry Supavisor.Registry.Tenants - @spec start(id, secrets) :: {:ok, pid} | {:error, any} - def start(id, secrets) do + @spec start(id, secrets, String.t() | nil) :: {:ok, pid} | {:error, any} + def start(id, secrets, db_name) do case get_global_sup(id) do nil -> - start_local_pool(id, secrets) + start_local_pool(id, secrets, db_name) pid -> {:ok, pid} @@ -156,8 +156,8 @@ defmodule Supavisor do ## Internal functions - @spec start_local_pool(id, secrets) :: {:ok, pid} | {:error, any} - defp start_local_pool({tenant, user, mode} = id, {method, secrets}) do + @spec start_local_pool(id, secrets, String.t() | nil) :: {:ok, pid} | {:error, any} + defp start_local_pool({tenant, user, mode} = id, {method, secrets}, db_name) do Logger.debug("Starting pool for #{inspect(id)}") case Tenants.get_pool_config(tenant, secrets.().alias) do @@ -194,7 +194,7 @@ defmodule Supavisor do host: String.to_charlist(db_host), port: db_port, user: db_user, - database: db_database, + database: if(db_name != nil, do: db_name, else: db_database), password: fn -> db_pass end, application_name: "supavisor", ip_version: H.ip_version(ip_ver, db_host), @@ -235,7 +235,8 @@ defmodule Supavisor do end end - @spec set_parameter_status(id, [{binary, binary}]) :: :ok | {:error, :not_found} + @spec set_parameter_status(id, [{binary, binary}]) :: + :ok | {:error, :not_found} def set_parameter_status(id, ps) do case get_local_manager(id) do nil -> {:error, :not_found} diff --git a/lib/supavisor/application.ex b/lib/supavisor/application.ex index 4b25aa53..91b33b08 100644 --- a/lib/supavisor/application.ex +++ b/lib/supavisor/application.ex @@ -61,6 +61,7 @@ defmodule Supavisor.Application do PromEx, {Registry, keys: :unique, name: Supavisor.Registry.Tenants}, {Registry, keys: :unique, name: Supavisor.Registry.ManagerTables}, + {Registry, keys: :unique, name: Supavisor.Registry.PoolPids}, {Registry, keys: :duplicate, name: Supavisor.Registry.TenantSups}, {Registry, keys: :duplicate, name: Supavisor.Registry.TenantClients}, {Cluster.Supervisor, [topologies, [name: Supavisor.ClusterSupervisor]]}, diff --git a/lib/supavisor/client_handler.ex b/lib/supavisor/client_handler.ex index 6d330124..311df525 100644 --- a/lib/supavisor/client_handler.ex +++ b/lib/supavisor/client_handler.ex @@ -57,7 +57,8 @@ defmodule Supavisor.ClientHandler do proxy_type: nil, mode: opts.mode, stats: %{}, - idle_timeout: 0 + idle_timeout: 0, + db_name: nil } :gen_statem.enter_loop(__MODULE__, [hibernate_after: 5_000], :exchange, data) @@ -70,6 +71,31 @@ defmodule Supavisor.ClientHandler do {:stop, :normal, data} end + # cancel request + def handle_event(:info, {_, _, <<16::32, 1234::16, 5678::16, pid::32, key::32>>}, _, data) do + Logger.debug("Got cancel query for #{inspect({pid, key})}") + :ok = HH.send_cancel_query(pid, key) + {:stop, :normal, data} + end + + # send cancel request to db + def handle_event(:info, :cancel_query, :busy, data) do + key = {data.tenant, data.db_pid} + Logger.debug("Cancel query for #{inspect(key)}") + + db_pid = data.db_pid + + case db_pid_meta(key) do + [{^db_pid, meta}] -> + :ok = HH.cancel_query(meta.host, meta.port, meta.ip_ver, meta.pid, meta.key) + + error -> + Logger.error("Received cancel but no proc was found #{inspect(key)} #{inspect(error)}") + end + + :keep_state_and_data + end + def handle_event(:info, {:tcp, _, <<_::64>>}, :exchange, %{sock: sock} = data) do Logger.debug("Client is trying to connect with SSL") @@ -109,6 +135,7 @@ defmodule Supavisor.ClientHandler do Logger.debug("Client startup message: #{inspect(hello)}") {user, external_id} = HH.parse_user_info(hello.payload) Logger.metadata(project: external_id, user: user, mode: data.mode) + data = %{data | db_name: hello.payload["database"]} {:keep_state, data, {:next_event, :internal, {:hello, {user, external_id}}}} {:error, error} -> @@ -192,7 +219,7 @@ defmodule Supavisor.ClientHandler do def handle_event(:internal, :subscribe, _, data) do Logger.debug("Subscribe to tenant #{inspect(data.id)}") - with {:ok, sup} <- Supavisor.start(data.id, data.auth_secrets), + with {:ok, sup} <- Supavisor.start(data.id, data.auth_secrets, data.db_name), {:ok, opts} <- Supavisor.subscribe(sup, data.id) do Process.monitor(opts.workers.manager) data = Map.merge(data, opts.workers) @@ -221,7 +248,10 @@ defmodule Supavisor.ClientHandler do end def handle_event(:internal, {:greetings, ps}, _, %{sock: sock} = data) do - :ok = HH.sock_send(sock, Server.greetings(ps)) + {header, <> = payload} = Server.backend_key_data() + msg = [ps, [header, payload], Server.ready_for_query()] + :ok = HH.listen_cancel_query(pid, key) + :ok = HH.sock_send(sock, msg) if data.idle_timeout > 0 do {:next_state, :idle, data, idle_check(data.idle_timeout)} @@ -632,4 +662,15 @@ defmodule Supavisor.ClientHandler do defp idle_check(timeout) do {:timeout, timeout, :idle_terminate} end + + defp db_pid_meta({_, pid} = key) do + rkey = Supavisor.Registry.PoolPids + fnode = node(pid) + + if fnode == node() do + Registry.lookup(rkey, key) + else + :erpc.call(fnode, Registry, :lookup, [rkey, key], 15_000) + end + end end diff --git a/lib/supavisor/db_handler.ex b/lib/supavisor/db_handler.ex index 1484284c..efee8abe 100644 --- a/lib/supavisor/db_handler.ex +++ b/lib/supavisor/db_handler.ex @@ -110,6 +110,13 @@ defmodule Supavisor.DbHandler do %{tag: :ready_for_query, payload: db_state}, {ps, _} -> {ps, db_state} + %{tag: :backend_key_data, payload: payload}, acc -> + key = {data.tenant, self()} + conn = %{host: data.auth.host, port: data.auth.port, ip_ver: data.auth.ip_version} + Registry.register(Supavisor.Registry.PoolPids, key, Map.merge(payload, conn)) + Logger.debug("Backend #{inspect(key)} data: #{inspect(payload)}") + acc + %{payload: {:authentication_sasl_password, methods_b}}, {ps, _} -> nonce = case Server.decode_string(methods_b) do diff --git a/lib/supavisor/handler_helpers.ex b/lib/supavisor/handler_helpers.ex index 226c2510..cef52004 100644 --- a/lib/supavisor/handler_helpers.ex +++ b/lib/supavisor/handler_helpers.ex @@ -1,6 +1,7 @@ defmodule Supavisor.HandlerHelpers do @moduledoc false + alias Phoenix.PubSub alias Supavisor, as: S alias Supavisor.Protocol.Server @@ -9,6 +10,13 @@ defmodule Supavisor.HandlerHelpers do mod.send(sock, data) end + @spec sock_close(nil | S.sock()) :: :ok | {:error, term()} + def sock_close(nil), do: :ok + + def sock_close({mod, sock}) do + mod.close(sock) + end + @spec setopts(S.sock(), term()) :: :ok | {:error, term()} def setopts({mod, sock}, opts) do mod = if mod == :gen_tcp, do: :inet, else: mod @@ -87,4 +95,28 @@ defmodule Supavisor.HandlerHelpers do {name, external_id} end end + + @spec send_cancel_query(non_neg_integer, non_neg_integer) :: :ok | {:errr, term} + def send_cancel_query(pid, key) do + PubSub.broadcast( + Supavisor.PubSub, + "cancel_req:#{pid}_#{key}", + :cancel_query + ) + end + + @spec listen_cancel_query(non_neg_integer, non_neg_integer) :: :ok | {:errr, term} + def listen_cancel_query(pid, key) do + PubSub.subscribe(Supavisor.PubSub, "cancel_req:#{pid}_#{key}") + end + + @spec cancel_query(keyword, non_neg_integer, atom, non_neg_integer, non_neg_integer) :: :ok + def cancel_query(host, port, ip_version, pid, key) do + msg = Server.cancel_message(pid, key) + opts = [:binary, {:packet, :raw}, {:active, true}, ip_version] + {:ok, sock} = :gen_tcp.connect(host, port, opts) + sock = {:gen_tcp, sock} + :ok = sock_send(sock, msg) + :ok = sock_close(sock) + end end diff --git a/lib/supavisor/native_handler.ex b/lib/supavisor/native_handler.ex index 66b76f68..b42f1c73 100644 --- a/lib/supavisor/native_handler.ex +++ b/lib/supavisor/native_handler.ex @@ -32,7 +32,9 @@ defmodule Supavisor.NativeHandler do trans: trans, acc: nil, status: :startup, - ssl: false + ssl: false, + db_auth: nil, + backend_key: nil } :gen_server.enter_loop(__MODULE__, [hibernate_after: 5_000], state) @@ -46,6 +48,16 @@ defmodule Supavisor.NativeHandler do {:stop, :normal, state} end + def handle_info( + {:tcp, sock, <<16::32, 1234::16, 5678::16, pid::32, key::32>>}, + %{status: :startup, client_sock: {_, sock} = client_sock} = state + ) do + Logger.debug("Got cancel query for #{inspect({pid, key})}") + :ok = HH.send_cancel_query(pid, key) + :ok = HH.sock_close(client_sock) + {:stop, :normal, state} + end + # ssl request from client def handle_info( {:tcp, sock, <<_::64>>} = _msg, @@ -85,6 +97,28 @@ defmodule Supavisor.NativeHandler do end # send packets to client from db + def handle_info( + {_, sock, bin}, + %{db_sock: {_, sock}, backend_key: nil} = state + ) do + state = + bin + |> Server.decode() + |> Enum.filter(fn e -> Map.get(e, :tag) == :backend_key_data end) + |> case do + [%{payload: %{key: key, pid: pid} = k}] -> + Logger.debug("Backend key: #{inspect(k)}") + :ok = HH.listen_cancel_query(pid, key) + %{state | backend_key: k} + + _ -> + state + end + + :ok = HH.sock_send(state.client_sock, bin) + {:noreply, state} + end + def handle_info({_, sock, bin}, %{db_sock: {_, sock}} = state) do :ok = HH.sock_send(state.client_sock, bin) {:noreply, state} @@ -115,9 +149,13 @@ defmodule Supavisor.NativeHandler do end |> Server.encode_startup_packet() - case connect_local(host, port, payload, state.ssl) do + ip_ver = H.detect_ip_version(host) + host = String.to_charlist(host) + + case connect_local(host, port, payload, ip_ver, state.ssl) do {:ok, db_sock} -> - {:noreply, %{state | db_sock: db_sock}} + auth = %{host: host, port: port, ip_ver: ip_ver} + {:noreply, %{state | db_sock: db_sock, db_auth: auth}} {:error, reason} -> Logger.error("Error connecting to tenant db: #{inspect(reason)}") @@ -141,14 +179,15 @@ defmodule Supavisor.NativeHandler do {:noreply, state} end - def handle_info({:tcp_closed, _} = msg, state) do - Logger.debug("Terminating #{inspect(msg, pretty: true)}") + def handle_info({closed, _} = msg, state) when closed in [:tcp_closed, :ssl_closed] do + Logger.debug("Closed socket #{inspect(msg, pretty: true)}") {:stop, :normal, state} end - def handle_info({:ssl_closed, _} = msg, state) do - Logger.debug("Terminating #{inspect(msg, pretty: true)}") - {:stop, :normal, state} + def handle_info(:cancel_query, %{backend_key: key, db_auth: auth} = state) do + Logger.debug("Cancel query for #{inspect(key)}") + :ok = HH.cancel_query(auth.host, auth.port, auth.ip_ver, key.pid, key.key) + {:noreply, state} end def handle_info(msg, state) do @@ -156,20 +195,25 @@ defmodule Supavisor.NativeHandler do {:noreply, state} end + @impl true + def terminate(_reason, state) do + Logger.debug("Terminate #{inspect(self())}") + :ok = HH.sock_close(state.db_sock) + :ok = HH.sock_close(state.client_sock) + end + ### Internal functions - @spec connect_local(String.t(), non_neg_integer, binary, boolean) :: + @spec connect_local(keyword, non_neg_integer, binary, atom, boolean) :: {:ok, S.sock()} | {:error, term()} - defp connect_local(host, port, payload, ssl?) do + defp connect_local(host, port, payload, ip_ver, ssl?) do sock_opts = [ :binary, {:packet, :raw}, {:active, false}, - H.detect_ip_version(host) + ip_ver ] - host = String.to_charlist(host) - with {:ok, sock} <- :gen_tcp.connect(host, port, sock_opts), {:ok, sock} <- HH.try_ssl_handshake({:gen_tcp, sock}, ssl?), :ok <- HH.sock_send(sock, payload) do diff --git a/lib/supavisor/protocol/server.ex b/lib/supavisor/protocol/server.ex index bcb25a18..96c04e38 100644 --- a/lib/supavisor/protocol/server.ex +++ b/lib/supavisor/protocol/server.ex @@ -13,6 +13,7 @@ defmodule Supavisor.Protocol.Server do @ready_for_query <> @ssl_request <<8::32, 1234::16, 5679::16>> @scram_request <> + @msg_cancel_header <<16::32, 1234::16, 5678::16>> defmodule Pkt do @moduledoc "Representing a packet structure with tag, length, and payload fields." @@ -133,8 +134,8 @@ defmodule Supavisor.Protocol.Server do end end - def decode_payload(:backend_key_data, <>) do - %{procid: proc_id, secret: secret} + def decode_payload(:backend_key_data, <>) do + %{pid: pid, key: key} end def decode_payload(:ready_for_query, payload) do @@ -360,18 +361,13 @@ defmodule Supavisor.Protocol.Server do [<>, payload] end - @spec backend_key_data() :: iodata() + @spec backend_key_data() :: {iodata(), binary} def backend_key_data() do - procid = System.unique_integer([:positive, :monotonic]) - secret = Enum.random(0..9_999_999_999) - payload = <> + pid = System.unique_integer([:positive, :monotonic]) + key = :crypto.strong_rand_bytes(4) + payload = <> len = IO.iodata_length(payload) + 4 - [<>, payload] - end - - @spec greetings(iodata()) :: iodata() - def greetings(ps) do - [ps, backend_key_data(), @ready_for_query] + {<>, payload} end @spec ready_for_query() :: binary() @@ -444,4 +440,9 @@ defmodule Supavisor.Protocol.Server do <> end + + @spec cancel_message(non_neg_integer, non_neg_integer) :: iodata + def cancel_message(pid, key) do + [@msg_cancel_header, <>] + end end diff --git a/test/supavisor/protocol_test.exs b/test/supavisor/protocol_test.exs index 128cd8b9..1fc7c00d 100644 --- a/test/supavisor/protocol_test.exs +++ b/test/supavisor/protocol_test.exs @@ -46,40 +46,21 @@ defmodule Supavisor.ProtocolTest do end test "backend_key_data/0" do - result = S.backend_key_data() - payload = Enum.at(result, 1) + {header, payload} = S.backend_key_data() len = byte_size(payload) + 4 assert [ %S.Pkt{ tag: :backend_key_data, len: 13, - payload: %{procid: _, secret: _} + payload: %{pid: _, key: _} } - ] = S.decode(result |> IO.iodata_to_binary()) + ] = S.decode([header, payload] |> IO.iodata_to_binary()) - assert hd(result) == <> + assert header == <> assert byte_size(payload) == 8 end - test "greetings/1" do - ps = S.encode_parameter_status(@initial_data) - - dec = - S.greetings(ps) - |> IO.iodata_to_binary() - |> S.decode() - - ready_for_query_pos = Enum.at(dec, -1) - backend_key_data_pos = Enum.at(dec, -2) - assert %S.Pkt{tag: :ready_for_query} = ready_for_query_pos - assert %S.Pkt{tag: :backend_key_data} = backend_key_data_pos - tags = Enum.map(dec, & &1.tag) - assert Enum.count(tags, &(&1 == :parameter_status)) == 13 - assert Enum.count(tags, &(&1 == :backend_key_data)) == 1 - assert Enum.count(tags, &(&1 == :ready_for_query)) == 1 - end - test "decode_payload for error_response" do assert S.decode(@auth_bin_error) == [ %Supavisor.Protocol.Server.Pkt{ @@ -97,4 +78,13 @@ defmodule Supavisor.ProtocolTest do } ] end + + test "cancel_message/2" do + pid = 123 + key = 123_456 + expected = <<0, 0, 0, 16, 4, 210, 22, 46, 0, 0, 0, 123, 0, 1, 226, 64>> + + assert S.cancel_message(pid, key) + |> IO.iodata_to_binary() == expected + end end diff --git a/test/supavisor/syn_handler_test.exs b/test/supavisor/syn_handler_test.exs index fd4e4a54..b6b26d47 100644 --- a/test/supavisor/syn_handler_test.exs +++ b/test/supavisor/syn_handler_test.exs @@ -11,7 +11,7 @@ defmodule Supavisor.SynHandlerTest do secret = %{alias: "postgres"} auth_secret = {:password, fn -> secret end} - {:ok, pid2} = :erpc.call(node2, Supavisor.FixturesHelpers, :start_pool, [@id, secret]) + {:ok, pid2} = :erpc.call(node2, Supavisor.FixturesHelpers, :start_pool, [@id, secret, nil]) Process.sleep(500) assert pid2 == Supavisor.get_global_sup(@id) assert node(pid2) == node2 @@ -19,7 +19,7 @@ defmodule Supavisor.SynHandlerTest do Process.sleep(500) assert nil == Supavisor.get_global_sup(@id) - {:ok, pid1} = Supavisor.start(@id, auth_secret) + {:ok, pid1} = Supavisor.start(@id, auth_secret, nil) assert pid1 == Supavisor.get_global_sup(@id) assert node(pid1) == node() diff --git a/test/support/fixtures/helpers.ex b/test/support/fixtures/helpers.ex index 0e593c96..4467361e 100644 --- a/test/support/fixtures/helpers.ex +++ b/test/support/fixtures/helpers.ex @@ -1,8 +1,8 @@ defmodule Supavisor.FixturesHelpers do @moduledoc false - def start_pool(id, secret) do + def start_pool(id, secret, db_name) do secret = {:password, fn -> secret end} - Supavisor.start(id, secret) + Supavisor.start(id, secret, db_name) end end