Skip to content

Commit

Permalink
Validate startup packets
Browse files Browse the repository at this point in the history
Supavisor's ClientHandler can receive a lot of junk, as any client can attempt to connect to it. This change filters out the noise from unserious clients (like bots/scrapers) sending requests.
  • Loading branch information
acco committed Jul 6, 2023
1 parent 313e648 commit 16be8bd
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 21 deletions.
72 changes: 51 additions & 21 deletions lib/supavisor/client_handler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,27 @@ defmodule Supavisor.ClientHandler do
end

def handle_event(:info, {:tcp, _, bin}, :exchange, %{socket: socket} = data) do
hello = decode_startup_packet(bin)
Logger.warning("Client startup message: #{inspect(hello)}")
{user, external_id} = parse_user_info(hello.payload["user"])
Logger.metadata(project: external_id, user: user)
with {:ok, hello} <- decode_startup_packet(bin) do
Logger.warning("Client startup message: #{inspect(hello)}")
{user, external_id} = parse_user_info(hello.payload["user"])
Logger.metadata(project: external_id, user: user)

case Tenants.get_user(external_id, user) do
{:ok, user_info} ->
new_data = update_user_data(data, external_id, user_info)
case Tenants.get_user(external_id, user) do
{:ok, user_info} ->
new_data = update_user_data(data, external_id, user_info)

{:keep_state, new_data,
{:next_event, :internal, {:handle, fn -> user_info.db_password end}}}
{:keep_state, new_data,
{:next_event, :internal, {:handle, fn -> user_info.db_password end}}}

{:error, reason} ->
Logger.error("User not found: #{inspect(reason)} #{inspect({user, external_id})}")
Server.send_error(socket, "XX000", "Tenant or user not found")
{:stop, :normal, data}
{:error, reason} ->
Logger.error("User not found: #{inspect(reason)} #{inspect({user, external_id})}")
Server.send_error(socket, "XX000", "Tenant or user not found")
{:stop, :normal, data}
else
{:error, :bad_startup_payload} ->
Logger.warn("Bad startup packet received", bin: bin)
{:stop, :normal, data}
end
end
end

Expand Down Expand Up @@ -295,20 +300,45 @@ defmodule Supavisor.ClientHandler do
end

def decode_startup_packet(<<len::integer-32, _protocol::binary-4, rest::binary>>) do
%{
len: len,
payload:
String.split(rest, <<0>>, trim: true)
|> Enum.chunk_every(2)
|> Enum.into(%{}, fn [k, v] -> {k, v} end),
tag: :startup
}
with {:ok, payload} <- decode_startup_packet_payload(rest) do
pkt = %{
len: len,
payload: payload,
tag: :startup
}

{:ok, pkt}
end
end

def decode_startup_packet(_) do
:undef
end

# The startup packet payload is a list of key/value pairs, separated by null bytes
defp decode_startup_packet_payload(payload) do
fields = String.split(payload, <<0>>, trim: true)

# If the number of fields is odd, then the payload is malformed
if rem(length(fields), 2) == 1 do
{:error, :bad_startup_payload}
else
map =
fields
|> Enum.chunk_every(2)
|> Enum.map(fn [k, v] -> {k, v} end)
|> Map.new()

# We only do light validation on the fields in the payload. The only field we use at the
# moment is `user`. If that's missing, this is a bad payload.
if Map.has_key?(map, "user") do
{:ok, map}
else
{:error, :bad_startup_payload}
end
end
end

@spec handle_exchange(port, fun) :: :ok | {:error, String.t()}
def handle_exchange(socket, password) do
:ok = Server.send_request_authentication(socket)
Expand Down
24 changes: 24 additions & 0 deletions test/supavisor/client_handler_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,28 @@ defmodule Supavisor.ClientHandlerTest do
assert external_id == "external_id"
end
end

describe "decode_startup_packet/1" do
test "handles bad startup packets" do
packet = <<0, 0, 0, 8, 0, 0, 0, 0, 3>>
assert {:error, _} = ClientHandler.decode_startup_packet(packet)
end

test "handles valid startup packets" do
payload = %{
"DateStyle" => "ISO",
"TimeZone" => "Asia/Tokyo",
"client_encoding" => "UTF8",
"database" => "mydbname",
"extra_float_digits" => "2",
"user" => "tenant.mytenant"
}

fields = Enum.reduce(payload, [], fn {k, v}, acc -> [k, v | acc] end) |> Enum.join(<<0>>)
len = String.length(fields) + 4
packet = <<len::integer-32, "prot"::binary, fields::binary>>
assert {:ok, hello} = ClientHandler.decode_startup_packet(packet)
assert hello[:payload]["user"] == "tenant.mytenant"
end
end
end

0 comments on commit 16be8bd

Please sign in to comment.