Commits

Oliver Gu committed 4fccde0

Added more robust version checks of TWS responses

Comments (0)

Files changed (2)

   let val_type = Val_type.create to_string of_string
 end
 
+module Header = struct
+  type 'a t = {
+    tag : 'a;
+    version : int;
+  } with sexp
+  let create ~tag ~version = { tag; version; }
+end
+
 module Ibx_error = struct
   type t =
   | Connection_closed
   | Read_error of Sexp.t
   | Parse_error of Sexp.t
   | Tws_error of string
-  | Unknown_response_handler of Query_id.t * Recv_tag.t
+  | Unknown_response_handler of Query_id.t * Recv_tag.t * [ `Version of int ]
   | Version_failure of int * Recv_tag.t
-  | Unpickler_mismatch of Sexp.t * Recv_tag.t list
+  | Unpickler_mismatch of Sexp.t * Recv_tag.t Header.t list
   | Uncaught_exn of Sexp.t
   with sexp
   exception Ibx of t with sexp
        | `Die of Ibx_error.t ] Deferred.t
 
   type t = {
-    tag  : Recv_tag.t;
-    run  : handler;
+    tag     : Recv_tag.t;
+    version : int;
+    run     : handler;
   }
 
-  let create ~tag ~run = { tag; run }
+  let create ~tag ~version ~run = { tag; version; run }
 end
 
 module type Connection_internal = sig
   val cancel_streaming :
     ?query:Query.t
     -> t
-    -> tags:Recv_tag.t list
+    -> recv_header:Recv_tag.t Header.t list
     -> query_id:Query_id.t
     -> (unit, [ `Closed ]) Result.t
 
   type t =
     { writer           : Writer.t;
       reader           : string Pipe.Reader.t;
-      open_queries     : (Query_id.t * Recv_tag.t, response_handler) Hashtbl.t;
+      open_queries     : (Query_id.t * Recv_tag.t * version, response_handler) Hashtbl.t;
       default_query_id : Query_id.t;
       next_order_id    : Order_id.t Ivar.t;
       account_code     : Account_code.t Ivar.t;
       logfun           : logfun option;
       extend_error     : Error.t -> unit;
     }
+  and version = int
   and response_handler = Response_handler.handler
   and logfun = [ `Send of Query.t | `Recv of Response.t ] -> unit
 
-  let init_handler ?id t ~tag ~unpickler ~action ~f =
+  let init_handler ?id t ~tag ~version ~unpickler ~action ~f =
     Hashtbl.replace t.open_queries
-      ~key:(Option.value id ~default:t.default_query_id, tag)
+      ~key:(Option.value id ~default:t.default_query_id, tag, version)
       ~data:(fun response ->
         match response.Response.data with
         | Error err ->
     in
     init_handler t
       ~tag:Recv_tag.Tws_error
+      ~version:2
       ~unpickler:Tws_error.unpickler
       ~action:`Keep
       ~f:(fun e -> extend_status ("TWS " ^ Tws_error.to_string_hum e));
     init_handler t
       ~tag:Recv_tag.Next_order_id
+      ~version:1
       ~unpickler:Next_order_id.unpickler
       ~action:`Remove
       ~f:(Ivar.fill t.next_order_id);
     init_handler t
       ~tag:Recv_tag.Managed_accounts
+      ~version:1
       ~unpickler:Account_code.unpickler
       ~action:`Remove
       ~f:(Ivar.fill_if_empty t.account_code);
     init_handler t
       ~tag:Recv_tag.Execution_report
+      ~version:9
       ~unpickler:Execution_report.unpickler
       ~action:`Keep
       ~f:extend_execution_report;
     init_handler t
       ~tag:Recv_tag.Commission_report
+      ~version:1
       ~unpickler:Commission_report.unpickler
       ~action:`Keep
       ~f:extend_commission_report;
       let id = Option.value ~default:t.default_query_id query.Query.id in
       List.iter handlers ~f:(fun h ->
         let tag = h.Response_handler.tag in
+        let version = h.Response_handler.version in
         let run = h.Response_handler.run in
-        Hashtbl.replace t.open_queries ~key:(id, tag) ~data:run);
+        Hashtbl.replace t.open_queries ~key:(id, tag, version) ~data:run);
       Ok ()
 
-  let cancel_streaming ?query t ~tags ~query_id =
+  let cancel_streaming ?query t ~recv_header ~query_id =
     match writer t with
     | Error `Closed as x -> x
     | Ok writer ->
       Option.iter query ~f:(send_query ?logfun:t.logfun writer);
-      List.iter tags ~f:(fun tag ->
-        match Hashtbl.find t.open_queries (query_id, tag) with
+      List.iter recv_header ~f:(fun header ->
+        let tag = header.Header.tag in
+        let version = header.Header.version in
+        match Hashtbl.find t.open_queries (query_id, tag, version) with
         | None -> ()
         | Some response_handler ->
           don't_wait_for (Deferred.ignore (
             response_handler
               { Response.
                 tag;
-                version  = 1;
+                version;
                 query_id = Some query_id;
-                data     = Ok `Cancel;
+                data = Ok `Cancel;
               }))
       );
       Ok ()
   let handle_response t response =
     let id = Option.value ~default:t.default_query_id response.Response.query_id in
     let tag = response.Response.tag in
-    match Hashtbl.find t.open_queries (id, tag) with
+    let version = response.Response.version in
+    let key = (id, tag, version) in
+    match Hashtbl.find t.open_queries key with
     | None ->
-      return (`Stop (Ibx_error.Unknown_response_handler (id, tag)))
+      return (`Stop (Ibx_error.Unknown_response_handler (id, tag, `Version version)))
     | Some f  ->
       f response >>| fun action ->
       begin
         match action with
         | `Remove ->
-          Hashtbl.remove t.open_queries (id, tag);
+          Hashtbl.remove t.open_queries key;
           `Continue
         | `Keep ->
           `Continue
         | `Replace handler ->
-          Hashtbl.replace t.open_queries ~key:(id, tag) ~data:handler;
+          Hashtbl.replace t.open_queries ~key ~data:handler;
           `Continue
         | `Die err ->
           `Stop err
           | exn -> Ibx_error.Uncaught_exn (Exn.sexp_of_t exn)
         in
         Hashtbl.iter t.open_queries
-          ~f:(fun ~key:(id, tag) ~data:response_handler ->
+          ~f:(fun ~key:(id, tag, version) ~data:response_handler ->
             don't_wait_for (Deferred.ignore (
               response_handler
                 { Response.
                   tag;
-                  version  = 1;
+                  version;
                   query_id = Some id;
-                  data     = Error error
+                  data = Error error
                 })));
         t.extend_error (Ibx_error.to_error error));
     Scheduler.within ~monitor loop
     end >>| Ibx_result.or_error
 end
 
-module Header = struct
-  type 'a t = {
-    tag : 'a;
-    version : int;
-  }
-
-  let create ~tag ~version = { tag; version; }
-  let tag t = t.tag
-end
-
 module Request = struct
   type ('query, 'response) t =
     { send_header  : Send_tag.t Header.t;
     let handler =
       Response_handler.create
         ~tag:t.recv_header.Header.tag
+        ~version:t.recv_header.Header.version
         ~run:(fun response ->
-          if response.Response.version = t.recv_header.Header.version then
-            match response.Response.data with
-            | Error err as x ->
-              (* If this handler died before, the ivar is already filled. *)
-              Ivar.fill_if_empty ivar x;
-              return (`Die err)
-            | Ok `Cancel ->
-              assert false (* Non-streaming requests are not cancelable. *)
-            | Ok (`Response data) ->
-              begin
-                match of_tws t.tws_response data with
-                | Error err as x ->
-                  Ivar.fill ivar x;
-                  return (`Die err)
-                | Ok _response as x ->
-                  Ivar.fill ivar x;
-                  return `Remove
-              end
-          else begin
-            let err = Ibx_error.Version_failure (
-              response.Response.version,
-              t.recv_header.Header.tag
-            ) in
-            Ivar.fill ivar (Error err);
+          match response.Response.data with
+          | Error err as x ->
+            (* If this handler died before, the ivar is already filled. *)
+            Ivar.fill_if_empty ivar x;
             return (`Die err)
-          end)
+          | Ok `Cancel ->
+            assert false (* Non-streaming requests are not cancelable. *)
+          | Ok (`Response data) ->
+            begin
+              match of_tws t.tws_response data with
+              | Error err as x ->
+                Ivar.fill ivar x;
+                return (`Die err)
+              | Ok _response as x ->
+                Ivar.fill ivar x;
+                return `Remove
+            end)
     in
     begin
       match Connection.dispatch con ~handlers:[handler] query with
     let error_handler =
       Response_handler.create
         ~tag:Recv_tag.Tws_error
+        ~version:1
         ~run:(fun response ->
           match response.Response.data with
           | Error err ->
         List.map skip_headers ~f:(fun header ->
           Response_handler.create
             ~tag:header.Header.tag
+            ~version:header.Header.version
             ~run:(fun response ->
-              (* Note: Skip handlers don't check the message version. *)
               match response.Response.data with
               | Error err -> return (`Die err)
               | Ok `Cancel -> return `Remove
         ~f:(fun header unpickler ->
           Response_handler.create
             ~tag:header.Header.tag
+            ~version:header.Header.version
             ~run:(fun response ->
-              if response.Response.version = header.Header.version then
-                let update pipe_w response =
-                  match response.Response.data with
-                  | Error err ->
-                    Pipe.close pipe_w;
-                    return (`Die err)
-                  | Ok `Cancel ->
-                    Pipe.close pipe_w;
-                    return `Remove
-                  | Ok (`Response data) ->
-                    begin
-                      match of_tws unpickler data with
-                      | Error err ->
-                        Pipe.close pipe_w;
-                        return (`Die err)
-                      | Ok response ->
-                        if not (Pipe.is_closed pipe_w) then begin
-                          (* We guard this write call to protect us against
-                             incoming messages after a cancelation, causing
-                             a write call to a closed pipe. *)
-                          don't_wait_for (Pipe.write pipe_w response)
-                        end;
-                        return `Keep
-                    end
-                in
+              let update pipe_w response =
                 match response.Response.data with
-                | Error err as x ->
+                | Error err ->
                   Pipe.close pipe_w;
-                  Ivar.fill_if_empty ivar x;
                   return (`Die err)
                 | Ok `Cancel ->
                   Pipe.close pipe_w;
                 | Ok (`Response data) ->
                   begin
                     match of_tws unpickler data with
-                    | Error err as x ->
+                    | Error err ->
                       Pipe.close pipe_w;
-                      Ivar.fill_if_empty ivar x;
                       return (`Die err)
                     | Ok response ->
-                      don't_wait_for (Pipe.write pipe_w response);
-                      (* We fill the ivar only in the first iteration. *)
-                      Ivar.fill_if_empty ivar (Ok (pipe_r, query_id));
-                      return (`Replace (update pipe_w))
+                      if not (Pipe.is_closed pipe_w) then begin
+                        (* We guard this write call to protect us against
+                           incoming messages after a cancelation, causing
+                           a write call to a closed pipe. *)
+                        don't_wait_for (Pipe.write pipe_w response)
+                      end;
+                      return `Keep
                   end
-              else begin
-                let err = Ibx_error.Version_failure (
-                  response.Response.version,
-                  header.Header.tag
-                ) in
-                Ivar.fill_if_empty ivar (Error err);
+              in
+              match response.Response.data with
+              | Error err as x ->
+                Pipe.close pipe_w;
+                Ivar.fill_if_empty ivar x;
                 return (`Die err)
-              end)))
+              | Ok `Cancel ->
+                Pipe.close pipe_w;
+                return `Remove
+              | Ok (`Response data) ->
+                begin
+                  match of_tws unpickler data with
+                  | Error err as x ->
+                    Pipe.close pipe_w;
+                    Ivar.fill_if_empty ivar x;
+                    return (`Die err)
+                  | Ok response ->
+                    don't_wait_for (Pipe.write pipe_w response);
+                    (* We fill the ivar only in the first iteration. *)
+                    Ivar.fill_if_empty ivar (Ok (pipe_r, query_id));
+                    return (`Replace (update pipe_w))
+                end)))
     in
     begin
       match data_handler_result with
       | Error exn ->
-        let tags = List.map t.recv_header ~f:Header.tag in
         don't_wait_for (Connection.close con);
-        Ivar.fill ivar (Error (Ibx_error.Unpickler_mismatch (Exn.sexp_of_t exn, tags)))
+        let err = Ibx_error.Unpickler_mismatch (Exn.sexp_of_t exn, t.recv_header) in
+        Ivar.fill ivar (Error err)
       | Ok data_handlers ->
         let handlers = error_handler :: (skip_handlers @ data_handlers) in
         match Connection.dispatch con ~handlers query with
-        | Ok () -> Connection.closed con >>> fun () -> Pipe.close pipe_w
+        | Ok () -> ()
         | Error `Closed -> Ivar.fill ivar (Error Ibx_error.Connection_closed)
     end;
     Ivar.read ivar >>| Ibx_result.or_error
   let dispatch_exn t con query = dispatch t con query >>| Or_error.ok_exn
 
   let cancel t con query_id =
-    let tags = Recv_tag.Tws_error :: (List.map t.recv_header ~f:Header.tag) in
+    let recv_header =
+      { Header.
+        tag = Recv_tag.Tws_error;
+        version = 1
+      } :: t.recv_header
+    in
     let result =
       match t.canc_header with
-      | None -> Connection.cancel_streaming con ~tags ~query_id
+      | None -> Connection.cancel_streaming con ~recv_header ~query_id
       | Some header ->
         let query =
           { Query.
             data    = "";
           }
         in
-        Connection.cancel_streaming con ~tags ~query_id ~query
+        Connection.cancel_streaming con ~recv_header ~query_id ~query
     in
     ignore (result : (unit, [ `Closed ]) Result.t)
 end
   ~send_header:(Ib.Header.create ~tag:S.Submit_order ~version:39)
   ~canc_header:(Ib.Header.create ~tag:S.Cancel_order ~version:1)
   ~recv_header:[Ib.Header.create ~tag:R.Order_status ~version:6]
-  ~skip_header:[Ib.Header.create ~tag:R.Open_order ~version:1]
+  ~skip_header:[Ib.Header.create ~tag:R.Open_order ~version:30]
   ~tws_query:Query.Submit_order.pickler
   ~tws_response:[Response.Order_status.unpickler]
   ()