Skip to content

Commit

Permalink
Misc improvements (#7)
Browse files Browse the repository at this point in the history
* Extract JWT helper in test
* Add test on publish endpoint (unit and integration)
* Add some integration test on public api
* Allow to configure sse timeout and sse keepalive from a client point
of view
* Remove dead file: `neurow/lib/neurow.ex`
* Allow to tune jwt max lifetime
* Avoid Application.fetch_env in public endpoint
* Export metrics for JWT errors
* Handle shutdown properly
  • Loading branch information
bpaquet authored Jul 22, 2024
1 parent 4c1a2e9 commit a7f4c76
Show file tree
Hide file tree
Showing 16 changed files with 552 additions and 76 deletions.
9 changes: 9 additions & 0 deletions neurow/config/runtime.exs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,19 @@ config :logger, :console,

# Every numeric setting below is read from an environment variable and parsed
# with String.to_integer/1, which raises if the variable is set to a
# non-integer value; the string after `||` is the default used when the
# variable is unset.

# PUBLIC_API_PORT, default 4000.
config :neurow, public_api_port: String.to_integer(System.get_env("PUBLIC_API_PORT") || "4000")

# PUBLIC_API_JWT_MAX_LIFETIME, default 120.
# NOTE(review): presumably seconds — confirm against Neurow.JwtAuthPlug.
config :neurow,
public_api_jwt_max_lifetime:
String.to_integer(System.get_env("PUBLIC_API_JWT_MAX_LIFETIME") || "120")

# INTERNAL_API_PORT, default 3000.
config :neurow,
internal_api_port: String.to_integer(System.get_env("INTERNAL_API_PORT") || "3000")

# INTERNAL_API_JWT_MAX_LIFETIME, default 120.
# NOTE(review): presumably seconds — confirm against Neurow.JwtAuthPlug.
config :neurow,
internal_api_jwt_max_lifetime:
String.to_integer(System.get_env("INTERNAL_API_JWT_MAX_LIFETIME") || "120")

# SSE_TIMEOUT / SSE_KEEPALIVE, in milliseconds: the public API loop compares
# them against :os.system_time(:millisecond) differences.
config :neurow, sse_timeout: String.to_integer(System.get_env("SSE_TIMEOUT") || "900000")
config :neurow, sse_keepalive: String.to_integer(System.get_env("SSE_KEEPALIVE") || "600000")

# Both values are nil when the corresponding variable is unset.
config :neurow, ssl_keyfile: System.get_env("SSL_KEYFILE")
config :neurow, ssl_certfile: System.get_env("SSL_CERTFILE")
Expand Down
18 changes: 0 additions & 18 deletions neurow/lib/neurow.ex

This file was deleted.

4 changes: 3 additions & 1 deletion neurow/lib/neurow/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ defmodule Neurow.Application do
name: Neurow.PubSub, options: [adapter: Phoenix.PubSub.PG2, pool_size: 10]},
{Plug.Cowboy, scheme: :http, plug: Neurow.InternalApi, options: [port: internal_api_port]},
{Plug.Cowboy,
scheme: sse_http_scheme, plug: Neurow.PublicApi, options: public_api_http_config}
scheme: sse_http_scheme, plug: Neurow.PublicApi, options: public_api_http_config},
{Plug.Cowboy.Drainer, refs: [Neurow.PublicApi.HTTP], shutdown: 20_000},
{StopListener, []}
]

MetricsPlugExporter.setup()
Expand Down
49 changes: 42 additions & 7 deletions neurow/lib/neurow/configuration.ex
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,39 @@ defmodule Neurow.Configuration do
end

def public_api_audience do
Application.fetch_env!(:neurow, :public_api_authentication)[:audience]
GenServer.call(__MODULE__, {:static_param, :public_api_audience})
end

def public_api_verbose_authentication_errors do
Application.fetch_env!(:neurow, :public_api_authentication)[:verbose_authentication_errors]
GenServer.call(__MODULE__, {:static_param, :public_api_verbose_authentication_errors})
end

def internal_api_issuer_jwks(issuer_name) do
GenServer.call(__MODULE__, {:internal_api_issuer_jwks, issuer_name})
end

def internal_api_audience do
Application.fetch_env!(:neurow, :internal_api_authentication)[:audience]
GenServer.call(__MODULE__, {:static_param, :internal_api_audience})
end

def internal_api_verbose_authentication_errors do
Application.fetch_env!(:neurow, :internal_api_authentication)[
:verbose_authentication_errors
]
GenServer.call(__MODULE__, {:static_param, :internal_api_verbose_authentication_errors})
end

@doc """
Returns the `:internal_api_jwt_max_lifetime` value held by the configuration
server (loaded from the application environment when the server starts).
"""
def internal_api_jwt_max_lifetime,
  do: GenServer.call(__MODULE__, {:static_param, :internal_api_jwt_max_lifetime})

@doc """
Returns the `:public_api_jwt_max_lifetime` value held by the configuration
server (loaded from the application environment when the server starts).
"""
def public_api_jwt_max_lifetime,
  do: GenServer.call(__MODULE__, {:static_param, :public_api_jwt_max_lifetime})

@doc """
Returns the default SSE inactivity timeout held by the configuration server.

The value comes from the `:sse_timeout` application environment key.
"""
def sse_timeout, do: GenServer.call(__MODULE__, {:static_param, :sse_timeout})

@doc """
Returns the default SSE keep-alive interval held by the configuration server.

The value comes from the `:sse_keepalive` application environment key.
"""
def sse_keepalive, do: GenServer.call(__MODULE__, {:static_param, :sse_keepalive})

@impl true
Expand All @@ -40,7 +54,23 @@ defmodule Neurow.Configuration do
},
internal_api: %{
issuer_jwks: build_issuer_jwks(:internal_api_authentication)
}
},
sse_keepalive: Application.fetch_env!(:neurow, :sse_keepalive),
sse_timeout: Application.fetch_env!(:neurow, :sse_timeout),
internal_api_jwt_max_lifetime:
Application.fetch_env!(:neurow, :internal_api_jwt_max_lifetime),
public_api_jwt_max_lifetime: Application.fetch_env!(:neurow, :public_api_jwt_max_lifetime),
internal_api_verbose_authentication_errors:
Application.fetch_env!(:neurow, :internal_api_authentication)[
:verbose_authentication_errors
],
public_api_verbose_authentication_errors:
Application.fetch_env!(:neurow, :public_api_authentication)[
:verbose_authentication_errors
],
internal_api_audience:
Application.fetch_env!(:neurow, :internal_api_authentication)[:audience],
public_api_audience: Application.fetch_env!(:neurow, :public_api_authentication)[:audience]
}}
end

Expand All @@ -54,6 +84,11 @@ defmodule Neurow.Configuration do
{:reply, state[:internal_api][:issuer_jwks][issuer_name], state}
end

# Serves a top-level entry of the server state by key; replies with `nil`
# when the key is not present. The state itself is left untouched.
@impl true
def handle_call({:static_param, param}, _caller, state) do
  value = Access.get(state, param)
  {:reply, value, state}
end

defp build_issuer_jwks(api_authentication_scope) do
Application.fetch_env!(:neurow, api_authentication_scope)[:issuers]
|> Enum.map(fn {issuer_name, shared_secrets} ->
Expand Down
59 changes: 45 additions & 14 deletions neurow/lib/neurow/internal_api.ex
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ defmodule Neurow.InternalApi do
audience: &Neurow.Configuration.internal_api_audience/0,
verbose_authentication_errors:
&Neurow.Configuration.internal_api_verbose_authentication_errors/0,
max_lifetime: &Neurow.Configuration.internal_api_jwt_max_lifetime/0,
count_error: &Stats.inc_jwt_errors_internal/0,
exclude_path_prefixes: ["/ping", "/nodes", "/cluster_size_above"]
)

Expand Down Expand Up @@ -49,27 +51,56 @@ defmodule Neurow.InternalApi do
|> send_resp((cluster_size >= size && 200) || 404, "Cluster size: #{cluster_size}\n")
end

post "v1/publish" do
issuer = conn.assigns[:jwt_payload]["iss"]
post "/v1/publish" do
case extract_params(conn) do
{:ok, message, topic} ->
message_id = to_string(:os.system_time(:millisecond))

topic = "#{issuer}-#{conn.body_params["topic"]}"
message = conn.body_params["message"]
:ok =
Phoenix.PubSub.broadcast!(Neurow.PubSub, topic, {:pubsub_message, message_id, message})

{:ok, body, _conn} = Plug.Conn.read_body(conn)
message_id = to_string(:os.system_time(:millisecond))
Logger.debug("Message published on topic: #{topic}")
Stats.inc_msg_received()

:ok =
Phoenix.PubSub.broadcast!(Neurow.PubSub, topic, {:pubsub_message, message_id, message})
conn
|> put_resp_header("content-type", "text/html")
|> send_resp(200, "Published #{message} to #{topic}\n")

Logger.debug("Message published on topic: #{topic}")
Stats.inc_msg_received()

conn
|> put_resp_header("content-type", "text/html")
|> send_resp(200, "Published #{body} to #{topic}\n")
{:error, reason} ->
conn |> resp(:bad_request, reason)
end
end

match _ do
send_resp(conn, 404, "")
end

# Collects what is needed to publish a message.
#
# Returns `{:ok, message, topic}`, where `topic` is the requested topic
# prefixed with the JWT issuer ("<issuer>-<topic>"), or the first
# `{:error, reason}` produced by one of the extraction steps.
defp extract_params(conn) do
  with {:ok, issuer} <- extract_issuer(conn),
       {:ok, message} <- extract_param(conn, "message"),
       {:ok, topic} <- extract_param(conn, "topic") do
    {:ok, message, "#{issuer}-#{topic}"}
  end
end

# Reads the "iss" claim from the JWT payload assigned to the connection.
# Returns `{:ok, issuer}`, or `{:error, reason}` when the claim is missing
# or is an empty string.
defp extract_issuer(conn) do
  issuer = conn.assigns[:jwt_payload]["iss"]

  cond do
    is_nil(issuer) -> {:error, "JWT iss is nil"}
    issuer == "" -> {:error, "JWT iss is empty"}
    true -> {:ok, issuer}
  end
end

# Reads `key` from the parsed request body.
# Returns `{:ok, value}`, or `{:error, reason}` when the parameter is missing
# or is an empty string.
defp extract_param(conn, key) do
  value = conn.body_params[key]

  cond do
    is_nil(value) -> {:error, "#{key} is nil"}
    value == "" -> {:error, "#{key} is empty"}
    true -> {:ok, value}
  end
end
end
5 changes: 4 additions & 1 deletion neurow/lib/neurow/jwt_auth_plug.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ defmodule Neurow.JwtAuthPlug do
defstruct [
:jwk_provider,
:audience,
:max_lifetime,
:count_error,
allowed_algorithm: "HS256",
max_lifetime: 60 * 2,
verbose_authentication_errors: false,
exclude_path_prefixes: []
]
Expand Down Expand Up @@ -56,9 +57,11 @@ defmodule Neurow.JwtAuthPlug do
conn |> assign(:jwt_payload, payload.fields)
else
{:error, code, message} ->
options.count_error.()
conn |> forbidden(code, message, options)

_ ->
options.count_error.()
conn |> forbidden(:authentication_error, "Authentication error", options)
end

Expand Down
54 changes: 47 additions & 7 deletions neurow/lib/neurow/public_api.ex
Original file line number Diff line number Diff line change
Expand Up @@ -9,48 +9,88 @@ defmodule Neurow.PublicApi do
jwk_provider: &Neurow.Configuration.public_api_issuer_jwks/1,
audience: &Neurow.Configuration.public_api_audience/0,
verbose_authentication_errors:
&Neurow.Configuration.public_api_verbose_authentication_errors/0
&Neurow.Configuration.public_api_verbose_authentication_errors/0,
max_lifetime: &Neurow.Configuration.public_api_jwt_max_lifetime/0,
count_error: &Stats.inc_jwt_errors_public/0
)

plug(:match)
plug(:dispatch)

get "v1/subscribe" do
# Opens a Server-Sent Events stream on the topic "<iss>-<sub>" taken from the
# JWT claims, then blocks in loop/5 until the stream ends.
#
# Clients may override the SSE timeout and keep-alive interval with the
# `x-sse-timeout` / `x-sse-keepalive` request headers. The values actually
# used are echoed back in the response headers of the same name.
get "/v1/subscribe" do
  case conn.assigns[:jwt_payload] do
    %{"iss" => issuer, "sub" => sub} ->
      topic = "#{issuer}-#{sub}"

      timeout =
        sse_header_integer(conn, "x-sse-timeout", &Neurow.Configuration.sse_timeout/0)

      keep_alive =
        sse_header_integer(conn, "x-sse-keepalive", &Neurow.Configuration.sse_keepalive/0)

      conn =
        conn
        |> put_resp_header("content-type", "text/event-stream")
        |> put_resp_header("cache-control", "no-cache")
        |> put_resp_header("connection", "close")
        |> put_resp_header("access-control-allow-origin", "*")
        |> put_resp_header("x-sse-server", to_string(node()))
        |> put_resp_header("x-sse-timeout", to_string(timeout))
        |> put_resp_header("x-sse-keepalive", to_string(keep_alive))

      :ok = Phoenix.PubSub.subscribe(Neurow.PubSub, topic)

      conn = send_chunked(conn, 200)

      Logger.debug("Client subscribed to #{topic}")

      last_message = :os.system_time(:millisecond)
      conn |> loop(timeout, keep_alive, last_message, last_message)
      Logger.debug("Client disconnected from #{topic}")
      conn

    _ ->
      conn |> resp(:bad_request, "Expected JWT claims are missing")
  end
end

# Reads a strictly positive integer from request header `header`.
#
# The header is client-controlled, so it is parsed with Integer.parse/2
# rather than String.to_integer/1, which raises on malformed input. A missing
# header, trailing garbage, or a value <= 0 all fall back to `default`, a
# zero-arity function that is only invoked when it is needed.
defp sse_header_integer(conn, header, default) do
  with {^header, raw} <- List.keyfind(conn.req_headers, header, 0),
       {value, ""} when value > 0 <- Integer.parse(raw) do
    value
  else
    _ -> default.()
  end
end

defp loop(conn, sse_timeout) do
# Streams pub/sub messages to the chunked SSE response until the stream ends.
#
#   * `sse_timeout` - max milliseconds without a published message
#   * `keep_alive`  - milliseconds of silence after which a ping event is sent
#   * `last_message` / `last_ping` - :os.system_time(:millisecond) timestamps
#     of the last message and of the last message-or-ping
#
# Returns `:timeout` (inactivity), `:close` (server is shutting down and the
# client was asked to reconnect) or `:closed` (writing to the client failed).
defp loop(conn, sse_timeout, keep_alive, last_message, last_ping) do
  receive do
    {:pubsub_message, msg_id, msg} ->
      case chunk(conn, "id: #{msg_id}\ndata: #{msg}\n\n") do
        {:ok, conn} ->
          Stats.inc_msg_published()
          now = :os.system_time(:millisecond)

          # The `after` clause only runs when no message matched for a full
          # second, so on a busy topic the shutdown check has to happen here
          # too or it would never be reached.
          if StopListener.close_connections?() do
            chunk(conn, "event: reconnect\n\n")
            :close
          else
            loop(conn, sse_timeout, keep_alive, now, now)
          end

        # The client is gone: stop instead of crashing on a failed match.
        {:error, _reason} ->
          :closed
      end
  after
    1000 ->
      now = :os.system_time(:millisecond)

      cond do
        # SSE Timeout
        now - last_message > sse_timeout ->
          Logger.debug("Client disconnected due to inactivity")
          :timeout

        # SSE Keep alive, send a ping. A failed write means the client has
        # disconnected, so stop rather than keep looping until the timeout.
        now - last_ping > keep_alive ->
          case chunk(conn, "event: ping\n\n") do
            {:ok, conn} -> loop(conn, sse_timeout, keep_alive, last_message, now)
            {:error, _reason} -> :closed
          end

        # We need to stop
        StopListener.close_connections?() ->
          chunk(conn, "event: reconnect\n\n")
          :close

        # Nothing
        true ->
          loop(conn, sse_timeout, keep_alive, last_message, last_ping)
      end
  end
end

Expand Down
16 changes: 16 additions & 0 deletions neurow/lib/stats.ex
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,16 @@ defmodule Stats do
help: "SSE Messages"
)

Gauge.declare(
name: :jwt_errors,
labels: [:kind],
help: "JWT Errors"
)

Gauge.set([name: :current_connections], 0)
Gauge.set([name: :connections], 0)
Gauge.set([name: :jwt_errors, labels: [:public]], 0)
Gauge.set([name: :jwt_errors, labels: [:internal]], 0)
Gauge.set([name: :messages, labels: [:received]], 0)
Gauge.set([name: :messages, labels: [:published]], 0)
end
Expand All @@ -40,4 +48,12 @@ defmodule Stats do
def inc_msg_published() do
Gauge.inc(name: :messages, labels: [:published])
end

# Adds one to the :jwt_errors metric for the :public label.
# NOTE(review): this ever-increasing count is declared as a Gauge; a Counter
# may be the better fit — confirm before changing the metric type.
def inc_jwt_errors_public(), do: Gauge.inc([name: :jwt_errors, labels: [:public]])

# Adds one to the :jwt_errors metric for the :internal label.
def inc_jwt_errors_internal(), do: Gauge.inc([name: :jwt_errors, labels: [:internal]])
end
30 changes: 30 additions & 0 deletions neurow/lib/stop_listener.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
defmodule StopListener do
  @moduledoc """
  Signals to long-lived connections that the node is shutting down.

  The process owns a named ETS table for its whole life. ETS deletes a table
  when its owner exits, so the *absence* of the table is the shutdown signal:
  `close_connections?/0` can be called from any process without messaging
  this server.
  """

  use GenServer
  require Logger

  def start_link(opts) do
    GenServer.start_link(__MODULE__, opts, name: __MODULE__)
  end

  @doc """
  Returns `true` when the listener's ETS table does not exist, i.e. the
  listener has stopped or was never started; `false` while it is running.
  """
  def close_connections?() do
    # Ask ETS directly whether the named table exists, rather than performing
    # a throwaway lookup and rescuing the ArgumentError it raises.
    :ets.whereis(__MODULE__) == :undefined
  end

  @impl true
  def init(_opts) do
    :ets.new(__MODULE__, [:set, :named_table, read_concurrency: true])
    # Trap exits so terminate/2 runs when the supervisor shuts this process down.
    Process.flag(:trap_exit, true)
    {:ok, %{shutdown_in_progress: false}}
  end

  @impl true
  def terminate(_reason, _state) do
    Logger.info("Graceful Shutdown occurring")
    :ok
  end
end
Loading

0 comments on commit a7f4c76

Please sign in to comment.