Source

dumbstreaming / dumbstreaming_io.ml

open Ds_types;

(* Functor building the dumbstreaming framing codec on top of an abstract
   monadic IO layer [IO] (an [IO_Type], from [Ds_types]).

   Single-string frames, written by [write1] and parsed by [read]:
     "<len>\n" <len bytes of payload> "\n\n"
   End of stream is an empty line ("\n"), written by [close_out]; [read]
   maps it to [None] and closes the input channel.

   Multi-string frames, written by [write]:
     "<total> <count> <len_1> ... <len_count>\n"
     <string_1> "\n" ... <string_count> "\n"
     "\n"
   NOTE(review): no reader for the multi-string format is visible in this
   functor; [read] only understands single-string frames. *)
module Make (IO : IO_Type)
 =
  struct

    (* Monadic bind in "m >>= f" argument order. *)
    value ( >>= ) = IO.bind_rev
    ;

    (* Raised (through [IO.error]) on any framing / protocol violation. *)
    exception Dumbstreaming of string
    ;

    (* [write1 outch str] writes [str] as one single-string frame
       ("<len>\n" ^ str ^ "\n\n") and flushes [outch]. *)
    value (write1 : IO.output_channel -> string -> IO.m unit) outch str =
      let pre = Printf.sprintf "%i\n" (String.length str)
      and post = "\n\n" in
      IO.write outch pre >>= fun () ->
      IO.write outch str >>= fun () ->
      IO.write outch post >>= fun () ->
      IO.flush outch
    ;

    (* [write outch strs] writes all of [strs] as one multi-string frame:
       a header "<total> <count> <len_1> ... <len_count>\n" ([total] is the
       sum of the lengths, computed in Int64), then each string followed by
       "\n", then a terminating "\n"; finally flushes [outch].
       For [strs = []] the header is "0 0 \n". *)
    value (write : IO.output_channel -> list string -> IO.m unit)
      outch
      strs
     =
      let count = List.length strs in
      let lens = List.map String.length strs in
      (* Summed in Int64, presumably so the total cannot overflow a native
         int on 32-bit platforms. *)
      let totalsize =
        List.fold_left (fun acc len -> Int64.add acc (Int64.of_int len))
          0L lens
      in
      let lens_txt = String.concat "\x20" (List.map string_of_int lens) in
      let pre = Printf.sprintf "%Li %i %s\x0A" totalsize count lens_txt in
      IO.write outch pre >>= fun () ->
      let rec loop lst =
        match lst with
        [ [] -> IO.return ()
        | [h :: t] ->
            IO.write outch h >>= fun () ->
            IO.write outch "\x0A" >>= fun () ->
            loop t
        ]
      in
      loop strs >>= fun () ->
      IO.write outch "\x0A" >>= fun () ->
      IO.flush outch
    ;

    (* [close_out outch] writes the end-of-stream marker (an empty line,
       which [read] maps to [None]), flushes, then closes [outch]. *)
    value close_out outch =
      IO.write outch "\n" >>= fun () ->
      IO.flush outch >>= fun () ->
      IO.close_out outch
    ;

    (* Maximum number of decimal digits accepted in a frame's length line. *)
    value max_len_digits = 10
    ;

    (* [err msg] fails the IO computation with [Dumbstreaming msg]. *)
    value err msg = IO.error (Dumbstreaming msg)
    ;

    (* Local alias for [Printf.sprintf]. *)
    value sprintf fmt = Printf.sprintf fmt
    ;

    (* [read_char inch] reads exactly one byte from [inch].
       Fails with "end of channel" when [IO.read_into] reports 0 bytes, and
       with "read_char: bad 'has_read'" on any count other than 0 or 1. *)
    value read_char inch =
      let str = String.make 1 '\x00' in
      IO.read_into inch str 0 1 >>= fun has_read ->
      if has_read = 0
      then err "end of channel"
      else if has_read = 1
      then IO.return str.[0]
      else err "read_char: bad 'has_read'"
    ;

    (* [read_len_or_eos inch] parses a length line: up to [max_len_digits]
       ASCII decimal digits terminated by '\n'.
       Returns [Some len] for a non-empty digit string and [None] for an
       empty line (the end-of-stream marker).  Fails on any other byte, on
       more than [max_len_digits] digits, or when the value does not fit in
       a native int (reachable on 32-bit platforms, where 10 digits can
       exceed [max_int]). *)
    value read_len_or_eos inch =
      inner ~anydigit:False max_len_digits 0
      where rec inner ~anydigit left acc =
        if left < 0
        then
          (* Too many digits; the "len" in this message is the digit
             limit, not a byte count. *)
          err (sprintf "len > %i" max_len_digits)
        else
          read_char inch >>= fun c ->
          match c with
          [ '\n' -> IO.return (if anydigit then Some acc else None)
          | '0'..'9' ->
              let d = Char.code c - Char.code '0' in
              (* Refuse instead of silently wrapping: 10 * acc + d must
                 stay <= max_int. *)
              if acc > (max_int - d) / 10
              then err "length does not fit in a native int"
              else inner ~anydigit:True (left - 1) (10 * acc + d)
          | _ -> err "expected decimal number (length)"
          ]
    ;

    (* [read_into_exact inch buf ofs len] fills [len] bytes of [buf]
       starting at [ofs] from [inch], retrying on short reads.
       Fails with "unexpected eof" when [IO.read_into] reports 0 bytes
       before [len] bytes have arrived, and with
       "read_into_exact: bad 'has_read'" on a count outside [1, len]. *)
    value read_into_exact inch buf ofs len =
      loop ~ofs ~len
      where rec loop ~ofs ~len =
        let () = assert (len >= 0) in
        if len = 0
        then
          IO.return ()
        else
          IO.read_into inch buf ofs len >>= fun has_read ->
          if has_read = 0
          then
            err "unexpected eof"
          else if has_read < 0 || has_read > len
          then
            (* Same defensive check as [read_char]: such a count would
               corrupt [ofs]/[len] in the recursive call. *)
            err "read_into_exact: bad 'has_read'"
          else
            loop ~ofs:(ofs + has_read) ~len:(len - has_read)
    ;

    (* [read_the_char inch c] consumes one byte and fails unless it is [c]. *)
    value read_the_char inch c =
      read_char inch >>= fun r ->
      if r = c
      then IO.return ()
      else err (sprintf "expected %C, found %C" c r)
    ;

    (* Consumes the "\n\n" trailer that ends a single-string frame. *)
    value read_msg_post inch =
      read_the_char inch '\n' >>= fun () ->
      read_the_char inch '\n'
    ;

    (* [read inch] reads one single-string frame (as written by [write1]).
       Returns [Some payload], or [None] on the end-of-stream marker, in
       which case [inch] is closed with [IO.close_in] first.
       Fails when the announced length exceeds [Sys.max_string_length]. *)
    value read inch =
      read_len_or_eos inch >>= fun len_or_eos ->
      match len_or_eos with
      [ None ->
          IO.close_in inch >>= fun () ->
          IO.return None
      | Some len ->
          if len > Sys.max_string_length
          then err "string is longer than Sys.max_string_length"
          else
            let r = String.make len '\x00' in
            read_into_exact inch r 0 len >>= fun () ->
            read_msg_post inch >>= fun () ->
            IO.return (Some r)
      ]
    ;

  end
;



(* Write-only IO implementation over [Buffer.t], meant to be passed to
   [Make].  The monad is a plain thunk, [m 'a = unit -> 'a]: nothing
   happens until a computation is applied to [()].  Output "channels" are
   [Buffer.t] values.  Input is unsupported ([input_channel = unit]); the
   reading primitives fail with [Failure] when run. *)
module Buffer_output (* : IO_Type *)
 =
  struct
    type m +'a = (unit -> 'a);
    value return x = fun () -> x;
    (* Run [ma], pass its result to [amb], then run the computation that
       [amb] returns. *)
    value bind_rev ma amb = fun () -> (amb (ma ()) ());
    (* Deferred like every other action: the exception is raised when the
       computation runs, in sequence, not when it is built. *)
    value error e = fun () -> raise e;
    type input_channel = unit;
    value read_into _ _ _ _ = fun () -> failwith "read_into";
    value close_in _ = fun () -> failwith "close_in";
    type output_channel = Buffer.t;
    value write b s = fun () -> Buffer.add_string b s;
    value flush _b = return ();
    value close_out _b = return ();
  end
;


(* The dumbstreaming codec instantiated over [Buffer_output], i.e. writing
   into [Buffer.t] values; used by [to_string]. *)
module Ds_buf =
  Make (Buffer_output)
;


(* [to_string lst] encodes [lst] with [Ds_buf.write] into a fresh buffer
   and returns the bytes.  The encoding is a header line
   "<total> <count> <len_1> ... <len_n>\n", each string followed by "\n",
   then a final "\n". *)
value to_string lst =
  let out = Buffer.create 20 in
  do {
    (* [Ds_buf.write] only builds a thunk; applying it to () performs the
       buffer writes. *)
    Ds_buf.write out lst ();
    Buffer.contents out
  }
;