Commits

Oliver Gu committed 39398bb

Refactored client

Comments (0)

Files changed (2)

         commission_reports = Tail.create ();
       }
 
+  exception Not_connected_yet with sexp
+
+  exception Eof_from_client with sexp
+
   exception Server_version_too_small of int * [ `Min of int ] with sexp
 
   let ignore_errors f = don't_wait_for (Monitor.try_with f >>| ignore)
   let connect t =
     let module C = Client_msg in
     let module E = Client_msg.Control in
-    match Or_error.try_with (fun () -> Socket.create Socket.Type.tcp) with
-    | Error err ->
+    match Result.try_with (fun () -> Socket.create Socket.Type.tcp) with
+    | Error exn ->
+      Tail.extend t.messages (C.Error (Error.of_exn exn));
       t.con <- `Disconnected;
-      return (Error err)
+      return ()
     | Ok s ->
-      let close_socket err =
+      let close_socket exn =
         t.con <- `Disconnected;
-        Tail.extend t.messages (C.Error err);
+        Tail.extend t.messages (C.Error (Error.of_exn exn));
         ignore_errors (fun () -> Unix.close (Socket.fd s));
       in
       Tail.extend t.messages (C.Control (E.Connecting (
         `Client_version t.client_version,
         Host_and_port.create ~host:t.remote_host ~port:t.remote_port
       )));
-      t.con <- `Connecting (fun () ->
-        close_socket (Ibx_error.to_error Ibx_error.Connection_closed));
+      t.con <- `Connecting (fun () -> close_socket Not_connected_yet);
       Monitor.try_with ~name:"connect socket" (fun () ->
         Unix.Inet_addr.of_string_or_getbyname t.remote_host
         >>= fun inet_addr ->
         Socket.connect s address)
       >>= function
       | Error exn ->
-        let err = Error.of_exn (Monitor.extract_exn exn) in
-        close_socket err;
-        return (Error err)
+        close_socket (Monitor.extract_exn exn);
+        return ()
       | Ok s ->
         let fd = Socket.fd s in
         Connection.create
           (Reader.create fd)
           (Writer.create fd)
         >>= fun con ->
-        let close_connection err =
+        let close_connection exn =
           t.con <- `Disconnected;
-          Tail.extend t.messages (C.Error err);
+          Tail.extend t.messages (C.Error (Error.of_exn exn));
           Connection.close con
         in
         Monitor.try_with ~name:"try connect" (fun () ->
           | Ok handshake_result ->
             begin match handshake_result with
             | H.Eof ->
-              Error.raise (Ibx_error.to_error Ibx_error.Unexpected_eof)
+              raise Eof_from_client
             | H.Version_failure version ->
               raise (Server_version_too_small (version, `Min Config.server_version))
             | H.Server_header (`Version version, conn_time, account_code) ->
                 E.Connected (`Server_version version, conn_time)));
             end)
         >>= function
-        | Error exn ->
-          let err = Error.of_exn (Monitor.extract_exn exn) in
-          close_connection err >>| fun () ->
-          Error err
-        | Ok () as x -> return x
+        | Error exn -> close_connection (Monitor.extract_exn exn)
+        | Ok () -> return ()
 
   let messages t = Tail.collect t.messages
   let execution_reports  t = Tail.collect t.execution_reports
   let commission_reports t = Tail.collect t.commission_reports
 
+  let client_id       t = t.client_id
+  let server_version  t = t.server_version
+  let connection_time t = t.connection_time
+  let account_code    t = t.account_code
+
+  let is_connected t =
+    match t.con with
+    | `Disconnected
+    | `Connecting _ -> false
+    | `Connected _  -> true
+
+  let state t =
+    match t.con with
+    | `Disconnected -> `Disconnected
+    | `Connecting _ -> `Connecting
+    | `Connected  _ -> `Connected
+
   let set_server_log_level t ~level =
     match t.con with
     | `Disconnected
       ~on_handler_error
       handler =
     let module C = Client_msg in
-    Monitor.try_with (fun () ->
-      create ?enable_logging ?client_id ~host ~port ()
-      >>= fun t ->
-      if t.enable_logging then begin
-        Stream.iter (messages t) ~f:(fun clt_msg ->
-          match clt_msg with
-          | C.Control x ->
-            Log.Global.sexp ~level:`Info  x C.Control.sexp_of_t
-          | C.Status x ->
-            Log.Global.sexp ~level:`Info  x String.sexp_of_t
-          | C.Error e ->
-            Log.Global.sexp ~level:`Error e Error.sexp_of_t;
-            Error.raise e)
-      end else begin
-        Stream.iter (messages t) ~f:(fun clt_msg ->
-          match clt_msg with
-          | C.Control _
-          | C.Status  _ -> ()
-          | C.Error e -> Error.raise e)
-      end;
-      Monitor.protect (fun () ->
-        connect t
-        >>= function
-        | Error e -> Error.raise e
-        | Ok () -> handler t
-      ) ~finally:(fun () -> disconnect t)
-    ) >>| fun result ->
-    match result with
-    | Ok () -> ()
-    | Error e ->
-      let e = Monitor.extract_exn e in
+    let handle_error e =
+      match on_handler_error with
+      | `Ignore -> ()
+      | `Raise  -> Error.raise e
+      | `Call f -> f e
+    in
+    create ?enable_logging ?client_id ~host ~port ()
+    >>= fun t ->
+    Stream.iter (messages t) ~f:(fun clt_msg ->
       begin
-        match on_handler_error with
-        | `Ignore -> ()
-        | `Raise  -> raise e
-        | `Call f -> f (Error.of_exn e)
-      end
-
-  let client_id       t = t.client_id
-  let server_version  t = t.server_version
-  let connection_time t = t.connection_time
-  let account_code    t = t.account_code
-
-  let is_connected t =
-    match t.con with
-    | `Disconnected
-    | `Connecting _ -> false
-    | `Connected _  -> true
-
-  let state t = match t.con with
-    | `Disconnected -> `Disconnected
-    | `Connecting _ -> `Connecting
-    | `Connected  _ -> `Connected
+        match clt_msg, t.enable_logging with
+        | C.Control x, true ->
+          Log.Global.sexp ~level:`Info x <:sexp_of< C.Control.t >>
+        | C.Status x, true ->
+          Log.Global.sexp ~level:`Info x <:sexp_of< string >>
+        | C.Error e, true ->
+          Log.Global.sexp ~level:`Error e <:sexp_of< Error.t >>;
+          handle_error e
+        | C.Error e, false ->
+          handle_error e
+        | _ -> ()
+      end);
+    connect t >>= fun () ->
+    match state t with
+    | `Connected -> handler t >>= fun () -> disconnect t
+    | _ -> return ()
 
   let dispatch_request t req query =
     match t.con with
     | `Disconnected
-    | `Connecting _  ->
-      return (Error (Ibx_error.to_error Ibx_error.Connection_closed))
+    | `Connecting _  -> return (Or_error.of_exn Not_connected_yet)
     | `Connected con -> Request.dispatch req con query
 
   let dispatch_streaming_request t req query =
     match t.con with
     | `Disconnected
-    | `Connecting _  ->
-      return (Error (Ibx_error.to_error Ibx_error.Connection_closed))
+    | `Connecting _  -> return (Or_error.of_exn Not_connected_yet)
     | `Connected con -> Streaming_request.dispatch req con query
 
   let cancel_streaming_request t req id =
       Pipe.read_at_most reader ~num_values:1
       >>| fun read_result ->
       Exn.protectx read_result ~f:(function
-      | `Eof -> Error (Ibx_error.to_error Ibx_error.Unexpected_eof)
+      | `Eof -> Or_error.of_exn Eof_from_client
       | `Ok result -> Ok (Queue.dequeue_exn result)
       ) ~finally:(fun _ -> cancel_streaming_request t req id)
 end
 
   (** [connect t] initiates a connection and returns a deferred that becomes
       determined when the connection is established. *)
-  val connect : t -> unit Or_error.t Deferred.t
+  val connect : t -> unit Deferred.t
 
   val disconnect : t -> unit Deferred.t