Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a ws implementation #43

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 34 additions & 15 deletions load_test/lib/load_test/user/sse.ex
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ defmodule SseUser do
{:ok, conn_pid} = :gun.open(String.to_atom(parsed_url.host), parsed_url.port, opts)
{:ok, proto} = :gun.await_up(conn_pid)
Logger.debug(fn -> "Connection established with proto #{inspect(proto)}" end)
stream_ref = :gun.get(conn_pid, parsed_url.path, headers)

stream_ref =
if parsed_url.scheme == "ws",
do: :gun.ws_upgrade(conn_pid, parsed_url.path, headers),
else: :gun.get(conn_pid, parsed_url.path, headers)

state = %SseState{
user_name: user_name,
Expand All @@ -79,12 +83,37 @@ defmodule SseUser do
wait_for_messages(state, conn_pid, stream_ref, expected_messages)
end

defp process_message(state, conn_pid, stream_ref, first_message, remaining_messages, msg) do
msg = String.trim(msg)
Logger.debug(fn -> "#{header(state)} Received message: #{inspect(msg)}" end)

if msg =~ "event: ping" do
wait_for_messages(state, conn_pid, stream_ref, [first_message | remaining_messages])
else
if check_message(state, msg, first_message) == :error do
:ok = :gun.close(conn_pid)
raise("#{header(state)} Message check error")
end

state = Map.put(state, :current_message, state.current_message + 1)
wait_for_messages(state, conn_pid, stream_ref, remaining_messages)
end
end

defp wait_for_messages(state, conn_pid, stream_ref, [first_message | remaining_messages]) do
Logger.debug(fn -> "#{header(state)} Waiting for message: #{first_message}" end)

result = :gun.await(conn_pid, stream_ref, state.sse_timeout)

case result do
{:upgrade, _, _} ->
Logger.debug(
"#{header(state)} Connected to websocket, waiting: #{length(remaining_messages) + 1} messages, url #{state.url}"
)

state.start_publisher_callback.()
wait_for_messages(state, conn_pid, stream_ref, [first_message | remaining_messages])

{:response, _, code, _} when code == 200 ->
Logger.debug(
"#{header(state)} Connected, waiting: #{length(remaining_messages) + 1} messages, url #{state.url}"
Expand All @@ -99,21 +128,11 @@ defmodule SseUser do
Stats.inc_msg_received_http_error()
raise("#{header(state)} Error")

{:ws, {:text, msg}} ->
process_message(state, conn_pid, stream_ref, first_message, remaining_messages, msg)

{:data, _, msg} ->
msg = String.trim(msg)
Logger.debug(fn -> "#{header(state)} Received message: #{inspect(msg)}" end)

if msg =~ "event: ping" do
wait_for_messages(state, conn_pid, stream_ref, [first_message | remaining_messages])
else
if check_message(state, msg, first_message) == :error do
:ok = :gun.close(conn_pid)
raise("#{header(state)} Message check error")
end

state = Map.put(state, :current_message, state.current_message + 1)
wait_for_messages(state, conn_pid, stream_ref, remaining_messages)
end
process_message(state, conn_pid, stream_ref, first_message, remaining_messages, msg)

msg ->
Logger.error("#{header(state)} Unexpected message #{inspect(msg)}")
Expand Down
14 changes: 13 additions & 1 deletion neurow/lib/neurow/application.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ defmodule Neurow.Application do
})
end

defp dispatcher do
[
{:_,
[
{Neurow.Configuration.public_api_context_path() <> "/v1/websocket",
Neurow.PublicApi.Websocket, []},
{:_, Plug.Cowboy.Handler, {Neurow.PublicApi.Endpoint, []}}
]}
]
end

def start(%{
public_api_port: public_api_port,
internal_api_port: internal_api_port,
Expand All @@ -50,7 +61,8 @@ defmodule Neurow.Application do
max_header_value_length: max_header_value_length,
idle_timeout: :infinity
],
transport_options: [max_connections: :infinity]
transport_options: [max_connections: :infinity],
dispatch: dispatcher()
]

{sse_http_scheme, public_api_http_config} =
Expand Down
121 changes: 121 additions & 0 deletions neurow/lib/neurow/public_api/websocket.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
defmodule Neurow.PublicApi.Websocket do
require Logger
@behaviour :cowboy_websocket

@loop_duration 5000

def init(req, _opts) do
Logger.debug("Starting websocket connection")

now_ms = :os.system_time(:millisecond)

{
:cowboy_websocket,
req,
%{
headers: req.headers,
last_ping_ms: now_ms,
last_message_ms: now_ms,
sse_timeout_ms: Neurow.Configuration.sse_timeout(),
keep_alive_ms: Neurow.Configuration.sse_keepalive(),
jwt_exp_s: :os.system_time(:second),
start_time: now_ms
},
%{
idle_timeout: 600_000,
max_frame_size: 1_000_000
}
}
end

defp authenticate(jwt_token, state) do
{_, payload} = JOSE.JWT.to_map(JOSE.JWT.peek_payload(jwt_token))
issuer = payload["iss"]
topic = "#{issuer}-#{payload["sub"]}"
Logger.debug("Authenticated with topic: #{topic}")
:ok = Neurow.StopListener.subscribe()
:ok = Phoenix.PubSub.subscribe(Neurow.PubSub, topic)
state = Map.put(state, :jwt_exp_s, payload["exp"])
state = Map.put(state, :issuer, issuer)
Neurow.Observability.MessageBrokerStats.inc_subscriptions(issuer)
Process.send_after(self(), :loop, @loop_duration)
{:ok, state}
end

def websocket_init(state \\ %{}) do
case state.headers["authorization"] do
"Bearer " <> jwt_token ->
authenticate(jwt_token, state)

_ ->
Logger.debug("No JWT token found in the headers")
Process.send_after(self(), :loop, @loop_duration)
{:ok, state}
end
end

def websocket_handle({:text, frame}, state) do
Logger.debug("Received unexpected text frame: #{inspect(frame)}")
{:ok, state}
end

def websocket_info(:loop, state) do
Process.send_after(self(), :loop, @loop_duration)
now_ms = :os.system_time(:millisecond)

cond do
# Check JWT auth
jwt_expired?(now_ms, state.jwt_exp_s) ->
Logger.info("Client disconnected due to credentials expired")

{[
{:text, "event: credentials_expired\n"},
{:close, 1000, "Credentials expired"}
], state}

# SSE timeout, send a timout event and stop the connection
sse_timed_out?(now_ms, state.last_message_ms, state.sse_timeout_ms) ->
Logger.info("Client disconnected due to inactivity")
{:stop, state}

# SSE Keep alive, send a ping
sse_needs_keepalive?(now_ms, state.last_ping_ms, state.keep_alive_ms) ->
state = Map.put(state, :last_ping_ms, now_ms)
{:reply, {:text, "event: ping\n"}, state}

true ->
{:ok, state}
end
end

def websocket_info({:pubsub_message, message}, state)
when is_struct(message, Neurow.Broker.Message) do
Neurow.Observability.MessageBrokerStats.inc_message_sent(state.issuer)
state = Map.put(state, :last_message_ms, :os.system_time(:millisecond))

{:reply,
{:text, "id: #{message.timestamp}\nevent: #{message.event}\ndata: #{message.payload}\n"},
state}
end

defp sse_timed_out?(now_ms, last_message_ms, sse_timeout_ms),
do: now_ms - last_message_ms > sse_timeout_ms

defp sse_needs_keepalive?(now_ms, last_ping_ms, keep_alive_ms),
do: now_ms - last_ping_ms > keep_alive_ms

defp jwt_expired?(now_ms, jwt_exp_s),
do: jwt_exp_s * 1000 < now_ms

def terminate(_reason, _req, state) do
if state.issuer do
duration_ms =
System.convert_time_unit(:os.system_time() - state.start_time, :native, :millisecond)

Neurow.Observability.MessageBrokerStats.dec_subscriptions(state.issuer, duration_ms)
end

Logger.debug("Terminating websocket connection")
:ok
end
end
Loading