diff --git a/load_test/lib/load_test/user/sse.ex b/load_test/lib/load_test/user/sse.ex index abe0880..2e22a69 100644 --- a/load_test/lib/load_test/user/sse.ex +++ b/load_test/lib/load_test/user/sse.ex @@ -62,7 +62,11 @@ defmodule SseUser do {:ok, conn_pid} = :gun.open(String.to_atom(parsed_url.host), parsed_url.port, opts) {:ok, proto} = :gun.await_up(conn_pid) Logger.debug(fn -> "Connection established with proto #{inspect(proto)}" end) - stream_ref = :gun.get(conn_pid, parsed_url.path, headers) + + stream_ref = + if parsed_url.scheme == "ws", + do: :gun.ws_upgrade(conn_pid, parsed_url.path, headers), + else: :gun.get(conn_pid, parsed_url.path, headers) state = %SseState{ user_name: user_name, @@ -79,12 +83,37 @@ defmodule SseUser do wait_for_messages(state, conn_pid, stream_ref, expected_messages) end + defp process_message(state, conn_pid, stream_ref, first_message, remaining_messages, msg) do + msg = String.trim(msg) + Logger.debug(fn -> "#{header(state)} Received message: #{inspect(msg)}" end) + + if msg =~ "event: ping" do + wait_for_messages(state, conn_pid, stream_ref, [first_message | remaining_messages]) + else + if check_message(state, msg, first_message) == :error do + :ok = :gun.close(conn_pid) + raise("#{header(state)} Message check error") + end + + state = Map.put(state, :current_message, state.current_message + 1) + wait_for_messages(state, conn_pid, stream_ref, remaining_messages) + end + end + defp wait_for_messages(state, conn_pid, stream_ref, [first_message | remaining_messages]) do Logger.debug(fn -> "#{header(state)} Waiting for message: #{first_message}" end) result = :gun.await(conn_pid, stream_ref, state.sse_timeout) case result do + {:upgrade, _, _} -> + Logger.debug( + "#{header(state)} Connected to websocket, waiting: #{length(remaining_messages) + 1} messages, url #{state.url}" + ) + + state.start_publisher_callback.() + wait_for_messages(state, conn_pid, stream_ref, [first_message | remaining_messages]) + {:response, _, code, _} when code == 200 -> Logger.debug( "#{header(state)} Connected, waiting: #{length(remaining_messages) + 1} messages, url #{state.url}" @@ -99,21 +128,11 @@ defmodule SseUser do Stats.inc_msg_received_http_error() raise("#{header(state)} Error") + {:ws, {:text, msg}} -> + process_message(state, conn_pid, stream_ref, first_message, remaining_messages, msg) + {:data, _, msg} -> - msg = String.trim(msg) - Logger.debug(fn -> "#{header(state)} Received message: #{inspect(msg)}" end) - - if msg =~ "event: ping" do - wait_for_messages(state, conn_pid, stream_ref, [first_message | remaining_messages]) - else - if check_message(state, msg, first_message) == :error do - :ok = :gun.close(conn_pid) - raise("#{header(state)} Message check error") - end - - state = Map.put(state, :current_message, state.current_message + 1) - wait_for_messages(state, conn_pid, stream_ref, remaining_messages) - end + process_message(state, conn_pid, stream_ref, first_message, remaining_messages, msg) msg -> Logger.error("#{header(state)} Unexpected message #{inspect(msg)}") diff --git a/neurow/lib/neurow/application.ex b/neurow/lib/neurow/application.ex index a38c2ff..8f627d7 100644 --- a/neurow/lib/neurow/application.ex +++ b/neurow/lib/neurow/application.ex @@ -32,6 +32,17 @@ defmodule Neurow.Application do }) end + defp dispatcher do + [ + {:_, + [ + {Neurow.Configuration.public_api_context_path() <> "/v1/websocket", + Neurow.PublicApi.Websocket, []}, + {:_, Plug.Cowboy.Handler, {Neurow.PublicApi.Endpoint, []}} + ]} + ] + end + def start(%{ public_api_port: public_api_port, internal_api_port: internal_api_port, @@ -50,7 +61,8 @@ defmodule Neurow.Application do max_header_value_length: max_header_value_length, idle_timeout: :infinity ], - transport_options: [max_connections: :infinity] + transport_options: [max_connections: :infinity], + dispatch: dispatcher() ] {sse_http_scheme, public_api_http_config} = diff --git a/neurow/lib/neurow/public_api/websocket.ex b/neurow/lib/neurow/public_api/websocket.ex new file mode 100644 index 0000000..bdeb2eb --- /dev/null +++ b/neurow/lib/neurow/public_api/websocket.ex @@ -0,0 +1,121 @@ +defmodule Neurow.PublicApi.Websocket do + require Logger + @behaviour :cowboy_websocket + + @loop_duration 5000 + + def init(req, _opts) do + Logger.debug("Starting websocket connection") + + now_ms = :os.system_time(:millisecond) + + { + :cowboy_websocket, + req, + %{ + headers: req.headers, + last_ping_ms: now_ms, + last_message_ms: now_ms, + sse_timeout_ms: Neurow.Configuration.sse_timeout(), + keep_alive_ms: Neurow.Configuration.sse_keepalive(), + jwt_exp_s: :os.system_time(:second), + start_time: now_ms + }, + %{ + idle_timeout: 600_000, + max_frame_size: 1_000_000 + } + } + end + + defp authenticate(jwt_token, state) do + {_, payload} = JOSE.JWT.to_map(JOSE.JWT.peek_payload(jwt_token)) + issuer = payload["iss"] + topic = "#{issuer}-#{payload["sub"]}" + Logger.debug("Authenticated with topic: #{topic}") + :ok = Neurow.StopListener.subscribe() + :ok = Phoenix.PubSub.subscribe(Neurow.PubSub, topic) + state = Map.put(state, :jwt_exp_s, payload["exp"]) + state = Map.put(state, :issuer, issuer) + Neurow.Observability.MessageBrokerStats.inc_subscriptions(issuer) + Process.send_after(self(), :loop, @loop_duration) + {:ok, state} + end + + def websocket_init(state \\ %{}) do + case state.headers["authorization"] do + "Bearer " <> jwt_token -> + authenticate(jwt_token, state) + + _ -> + Logger.debug("No JWT token found in the headers") + Process.send_after(self(), :loop, @loop_duration) + {:ok, state} + end + end + + def websocket_handle({:text, frame}, state) do + Logger.debug("Received unexpected text frame: #{inspect(frame)}") + {:ok, state} + end + + def websocket_info(:loop, state) do + Process.send_after(self(), :loop, @loop_duration) + now_ms = :os.system_time(:millisecond) + + cond do + # Check JWT auth + jwt_expired?(now_ms, state.jwt_exp_s) -> + Logger.info("Client disconnected due to credentials expired") + + {[ + {:text, "event: credentials_expired\n"}, + {:close, 1000, "Credentials expired"} + ], state} + + # SSE timeout, send a timout event and stop the connection + sse_timed_out?(now_ms, state.last_message_ms, state.sse_timeout_ms) -> + Logger.info("Client disconnected due to inactivity") + {:stop, state} + + # SSE Keep alive, send a ping + sse_needs_keepalive?(now_ms, state.last_ping_ms, state.keep_alive_ms) -> + state = Map.put(state, :last_ping_ms, now_ms) + {:reply, {:text, "event: ping\n"}, state} + + true -> + {:ok, state} + end + end + + def websocket_info({:pubsub_message, message}, state) + when is_struct(message, Neurow.Broker.Message) do + Neurow.Observability.MessageBrokerStats.inc_message_sent(state.issuer) + state = Map.put(state, :last_message_ms, :os.system_time(:millisecond)) + + {:reply, + {:text, "id: #{message.timestamp}\nevent: #{message.event}\ndata: #{message.payload}\n"}, + state} + end + + defp sse_timed_out?(now_ms, last_message_ms, sse_timeout_ms), + do: now_ms - last_message_ms > sse_timeout_ms + + defp sse_needs_keepalive?(now_ms, last_ping_ms, keep_alive_ms), + do: now_ms - last_ping_ms > keep_alive_ms + + defp jwt_expired?(now_ms, jwt_exp_s), + do: jwt_exp_s * 1000 < now_ms + + def terminate(_reason, _req, state) do + if state.issuer do + duration_ms = + System.convert_time_unit(:os.system_time() - state.start_time, :native, :millisecond) + + Neurow.Observability.MessageBrokerStats.dec_subscriptions(state.issuer, duration_ms) + end + + Logger.debug("Terminating websocket connection") + :ok + end +end