diff --git a/.merlin b/.merlin index cca8de9..510ea36 100644 --- a/.merlin +++ b/.merlin @@ -14,4 +14,4 @@ PKG conduit.lwt-unix PKG nocrypto PKG core PKG async -PKG cohttp.async \ No newline at end of file +PKG cohttp.async diff --git a/_tags b/_tags index 883349c..6d4dc5c 100644 --- a/_tags +++ b/_tags @@ -18,6 +18,13 @@ true: bin_annot, debug, safe_string package(nocrypto.lwt), \ package(conduit) +: package(lwt), \ + package(lwt.ppx), \ + package(uri), \ + package(cohttp.lwt), \ + package(containers), \ + package(conduit) + : package(async), \ package(uri), \ package(cohttp.async), \ @@ -43,4 +50,15 @@ true: bin_annot, debug, safe_string package(uri), \ package(cohttp.async), \ package(nocrypto.unix), \ - package(containers) \ No newline at end of file + package(containers) + +: \ + package(core), \ + package(containers), \ + package(cohttp), \ + package(cohttp.lwt), \ + package(conduit), \ + package(lwt), \ + package(lwt.ppx), \ + package(nocrypto), \ + package(ppx_deriving) diff --git a/lib/websocket_cohttp_lwt.ml b/lib/websocket_cohttp_lwt.ml new file mode 100644 index 0000000..5e726c0 --- /dev/null +++ b/lib/websocket_cohttp_lwt.ml @@ -0,0 +1,66 @@ +module C = Cohttp +module Lwt_IO = Websocket.IO(Cohttp_lwt_unix_io) + +open Lwt + +let send_frames stream oc = + let buf = Buffer.create 128 in + let send_frame fr = + Buffer.clear buf; + Lwt_IO.write_frame_to_buf ~masked:false buf fr; + Lwt_io.write oc @@ Buffer.contents buf + in + Lwt_stream.iter_s send_frame stream + ;; + +let read_frames icoc handler_fn = + let read_frame () = + let rf = Lwt_IO.make_read_frame ~masked:false icoc () in + match%lwt rf with + | `Ok frame -> Lwt.return frame + | `Error msg -> Lwt.fail_with msg + in + while%lwt true do + read_frame () >>= Lwt.wrap1 handler_fn + done + ;; + +let upgrade_connection request conn incoming_handler = + let headers = Cohttp.Request.headers request in + let key = CCOpt.get_exn @@ Cohttp.Header.get headers "sec-websocket-key" in + let hash = key ^ Websocket.websocket_uuid |> Websocket.b64_encoded_sha1sum in + let response_headers = + Cohttp.Header.of_list + ["Upgrade", "websocket" + ;"Connection", "Upgrade" + ;"Sec-WebSocket-Accept", hash] + in + let resp = + Cohttp.Response.make + ~status:`Switching_protocols + ~encoding:Cohttp.Transfer.Unknown + ~headers:response_headers + ~flush:true + () + in + + let frames_out_stream, frames_out_fn = Lwt_stream.create () in + + let body_stream, stream_push = Lwt_stream.create () in + let _ = + let open Conduit_lwt_unix in + match conn with + | TCP (tcp : tcp_flow) -> + let oc = Lwt_io.of_fd ~mode:Lwt_io.output tcp.fd in + let ic = Lwt_io.of_fd ~mode:Lwt_io.input tcp.fd in + Lwt.join [ + (* input: data from the client is read from the input channel + * of the tcp connection; pass it to handler function *) + read_frames (ic, oc) incoming_handler; + (* output: data for the client is written to the output + * channel of the tcp connection *) + send_frames frames_out_stream oc; + ] + | _ -> Lwt.fail_with "expected TCP Websocket connection" + in + Lwt.return (resp, Cohttp_lwt_body.of_stream body_stream, frames_out_fn) diff --git a/lib/websocket_cohttp_lwt.mli b/lib/websocket_cohttp_lwt.mli new file mode 100644 index 0000000..dfa6394 --- /dev/null +++ b/lib/websocket_cohttp_lwt.mli @@ -0,0 +1,5 @@ +val upgrade_connection: + Cohttp.Request.t -> + Conduit_lwt_unix.flow -> + (Websocket.Frame.t -> unit) -> + (Cohttp.Response.t * Cohttp_lwt_body.t * (Websocket.Frame.t option -> unit)) Lwt.t diff --git a/pkg/build.ml b/pkg/build.ml index 89a81e5..d6e22ca 100644 --- a/pkg/build.ml +++ b/pkg/build.ml @@ -14,4 +14,5 @@ let () = Pkg.bin ~cond:lwt ~auto:true "tests/wscat"; Pkg.bin ~cond:async ~auto:true "tests/wscat_async"; Pkg.bin ~cond:lwt ~auto:true "tests/reynir"; + Pkg.bin ~cond:lwt ~auto:true "tests/upgrade_connection"; ] diff --git a/pkg/topkg.ml b/pkg/topkg.ml index 9e2ba47..03320f0 100644 --- a/pkg/topkg.ml +++ b/pkg/topkg.ml @@ -273,7 +273,7 @@ module Pkg : Pkg = struct let describe pkg ~builder mvs = let mvs = List.sort compare (List.flatten mvs) in let btool, bdir = match builder with - | `OCamlbuild -> "ocamlbuild -use-ocamlfind -classic-display", "_build" + | `OCamlbuild -> "ocamlbuild -tag thread -use-ocamlfind -classic-display", "_build" | `Other (btool, bdir) -> btool, bdir in match Topkg.cmd with diff --git a/tests/upgrade_connection.ml b/tests/upgrade_connection.ml new file mode 100644 index 0000000..fb04aab --- /dev/null +++ b/tests/upgrade_connection.ml @@ -0,0 +1,102 @@ +open Lwt +open Core.Std + +let handler + (conn : Conduit_lwt_unix.flow * Cohttp.Connection.t) + (req : Cohttp_lwt_unix.Request.t) + (body : Cohttp_lwt_body.t) = + Lwt_io.eprintf + "[CONN] %s\n%!" (Cohttp.Connection.to_string @@ snd conn) + >>= fun _ -> + let uri = Cohttp.Request.uri req in + match Uri.path uri with + | "/" -> + Lwt_io.eprintf "[PATH] /\n%!" + >>= fun () -> + Cohttp_lwt_unix.Server.respond_string + ~status:`OK + ~body: {| + + + + + + + +
+ + + |} + () + | "/ws" -> + Lwt_io.eprintf "[PATH] /ws\n%!" + >>= fun () -> + Cohttp_lwt_body.drain_body body + >>= fun () -> + Websocket_cohttp_lwt.upgrade_connection req (fst conn) ( + fun f -> + match f.Websocket.Frame.opcode with + | Websocket.Frame.Opcode.Close -> + Printf.eprintf "[RECV] CLOSE\n%!" + | _ -> + Printf.eprintf "[RECV] %s\n%!" f.Websocket.Frame.content + ) + >>= fun (resp, body, frames_out_fn) -> + (* send a message to the client every second *) + let _ = + let num_ref = ref 10 in + let rec go () = + if !num_ref > 0 then + let msg = Printf.sprintf "-> Ping %d" !num_ref in + Lwt_io.eprintf "[SEND] %s\n%!" msg + >>= fun () -> + Lwt.wrap1 frames_out_fn @@ + Some ( + Websocket.Frame.of_bytes @@ + BytesLabels.of_string @@ + msg + ) + >>= fun () -> + Lwt.return (num_ref := !num_ref - 1) + >>= fun () -> + Lwt_unix.sleep 1. + >>= go + else + Lwt_io.eprintf "[INFO] Test done\n%!" + >>= Lwt.return + in + go () + in + Lwt.return (resp, (body :> Cohttp_lwt_body.t)) + | _ -> + Lwt_io.eprintf "[PATH] Catch-all\n%!" + >>= fun () -> + Cohttp_lwt_unix.Server.respond_string + ~status:`Not_found + ~body:(Sexp.to_string_hum (Cohttp.Request.sexp_of_t req)) + () + +let start_server host port () = + let conn_closed (ch,_) = + Printf.eprintf "[SERV] connection %s closed\n%!" + (Sexplib.Sexp.to_string_hum (Conduit_lwt_unix.sexp_of_flow ch)) + in + Lwt_io.eprintf "[SERV] Listening for HTTP on port %d\n%!" port + >>= fun _ -> + Cohttp_lwt_unix.Server.create + ~mode:(`TCP (`Port port)) + (Cohttp_lwt_unix.Server.make ~callback:handler ~conn_closed ()) + +(* main *) +let () = + Lwt_main.run (start_server "localhost" 7777 ())