Source

amall / src / amall_http.ml

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
open Am_All;
open Cd_All; open Strings.Latin1;

(* Longest accepted request-URI on the request line, in chars. *)
value max_uri_len = 4 * 1024;
(* Longest accepted single header line, in chars. *)
value max_header_len = 4 * 1024;
(* Upper bound for the whole header block, in chars. *)
value max_headers_size = 10 * 1024;
(* Scheme assumed for request URIs that carry none. *)
value default_uri_scheme = "http";

(**********)

(* HTTP request methods this module understands; their textual forms are
   listed in [request_method_texts]. *)
type request_method = [= `GET | `POST | `HEAD ]
;

(* Transfer codings named in "Transfer-Encoding" request headers.
   Unrecognised codings are kept verbatim in [`Other].
   NOTE(review): [make_headers] maps "identity" to no entry at all, so it
   never produces [`Identity]. *)
type transfer_encoding =
  [= `Identity
  |  `Chunked
  |  `Gzip
  |  `Other of string
  ]
;

(* Options of the "Connection" request header: [`Close] for the "close"
   option, anything else kept verbatim in [`Other]. *)
type connection_header_val =
  [= `Close
  |  `Other of string
  ]
;

open Uri_type
;

(* Structured view of the request headers, built by [make_headers]. *)
type rq_headers =
  { connection : list connection_header_val
      (* options gathered from every "Connection" header *)
  ; content_length : option int
      (* "Content-Length", when present and parseable *)
  ; transfer_length : option int
      (* "Transfer-Length", when present and parseable *)
  ; transfer_encoding : list transfer_encoding
      (* codings gathered from every "Transfer-Encoding" header *)
  ; rqh_host : option (host_kind * string * option int)
      (* (host kind, host, optional port) from the first "Host" header
         that [Uri.parse_host_portopt] accepts *)
  ; rq_all : list (string * string);
      (* the raw (name, value) list given to [make_headers]; [it_http]
         passes it in request order with names as sent *)
  }
;

(* Header record describing a request that sent no headers at all.
   Fields are listed in declaration order. *)
value empty_rq_headers : rq_headers =
  { connection = []
  ; content_length = None
  ; transfer_length = None
  ; transfer_encoding = []
  ; rqh_host = None
  ; rq_all = []
  }
;


(* A parsed request head, as produced by [it_http]. *)
type request =
  { rq_method : request_method
  ; rq_uri : uri
      (* effective URI: [rq_request_uri__] with the authority taken from
         the "Host" header when the request line has none, and the scheme
         defaulted to [default_uri_scheme] when missing *)
  ; rq_request_uri__ : uri  (* uri from "GET ... HTTP/*" line, not useful *)
  ; rq_version : (int * int)
      (* (major, minor) from "HTTP/major.minor" *)
  ; rq_headers : rq_headers
  }
;

(* Thrown (through [I.throw_err]) by [it_http] when the request head is
   malformed or unsupported; the payload is a human-readable reason. *)
exception Bad_request of string
;

(* Response headers as (name, value) pairs, written in list order.
   "Content-length" is recomputed from the body before output. *)
type rs_headers =
  { rs_all : list (string * string)
  }
;

(* Number of bytes of a file to send as a body (see [File_contents]). *)
type file_size = int64
;

(* Response body:
   - [No_body]: nothing is written, and "Content-length" is removed;
   - [Body_string s]: [s] is written as is;
   - [File_contents fn sz]: the first [sz] bytes of file [fn] are copied,
     and output fails if the file turns out to be shorter. *)
type rs_body =
  [ No_body
  | Body_string of string
  | File_contents of string and file_size
  ]
;

(* A response to send.  The status code must have exactly three digits and
   the reason phrase must not contain CR or LF; otherwise a 500 response is
   sent instead. *)
type response =
  { rs_status_code : int
  ; rs_reason_phrase : string
  ; rs_headers : rs_headers
  ; rs_body : rs_body
  }
;

(* Method names accepted on the request line and the variants they map to.
   [it_http] looks them up case-insensitively. *)
value request_method_texts =
  [ ("GET", `GET)
  ; ("POST", `POST)
  ; ("HEAD", `HEAD)
  ]
;

(* Length in bytes of the longest method name in [request_method_texts]
   (0 for an empty table); bounds how far [it_http] reads for a method. *)
value max_request_method_len =
  let longest best (text, _meaning) =
    let len = String.length_bytes text in
    if len > best then len else best
  in
  List.fold_left longest 0 request_method_texts
;


(* Build the structured [rq_headers] view from the raw (name, value)
   header list.

   - Header names are matched case-insensitively.
   - "Connection" and "Transfer-Encoding" values are comma-separated token
     lists.  Tokens are trimmed of spaces/tabs and compared
     case-insensitively; empty tokens are dropped.  Tokens from every
     occurrence of the header are accumulated in header order.
     "identity" contributes no transfer coding.
   - "Content-Length" / "Transfer-Length" must be a non-empty run of
     decimal digits.  Anything else (signs, "0x" prefixes, underscores,
     values that overflow [int]) is ignored, like any other unparseable
     value.  When the header repeats, the last valid value wins.
   - "Host": the first occurrence that [Uri.parse_host_portopt] accepts
     wins.
   - [rq_all] keeps [lst] untouched.
 *)
value (make_headers : list (string * string) -> rq_headers) lst =
  let connection = ref []
  and content_length = ref None
  and transfer_length = ref None
  and transfer_encoding = ref []
  and host = ref None
  in
  let is_token_space c = (c = '\x20' || c = '\x09') in
  (* split a header value on ',' into trimmed, non-empty tokens *)
  let tokens hv =
    List.filter (fun t -> t <> "")
      (List.map (String.trim is_token_space)
         (String.split_exact ((=) ',') hv))
  in
  (* [int_of_string] alone would also accept "-1", "0x1F" or "1_0"; a
     length header is 1*DIGIT, and a negative length must never reach
     the body reader *)
  let is_decimal s =
    s <> "" &&
    List.for_all (fun c -> c >= '0' && c <= '9') (String.explode s)
  in
  let parse_connection tok =
    if String.lowercase tok = "close" then `Close else `Other tok
  and parse_t_e tok =
    match String.lowercase tok with
    [ "identity" -> []
    | "chunked" -> [`Chunked]
    | "gzip" -> [`Gzip]
    | _ -> [`Other tok]
    ]
  in
  let () = List.iter
    (fun (hk, hv) ->
       let set_opt_int r =
         if is_decimal hv
         then (try r.val := Some (int_of_string hv) with [Failure _ -> ()])
         else ()
       (* append, so the accumulated lists follow header order *)
       and append r items = r.val := List.append r.val items in
       match String.lowercase hk with
       [ "connection" ->
           append connection (List.map parse_connection (tokens hv))
       | "content-length" -> set_opt_int content_length
       | "transfer-length" -> set_opt_int transfer_length
       | "transfer-encoding" ->
           append transfer_encoding
             (List.concat (List.map parse_t_e (tokens hv)))
       | "host" ->
           match (host.val, Uri.parse_host_portopt hv) with
           [ (None, ((Some _) as some_h)) -> host.val := some_h
           | _ -> ()
           ]
       | _ -> ()
       ]
    )
    lst
  in
    { connection = connection.val
    ; content_length = content_length.val
    ; transfer_length = transfer_length.val
    ; transfer_encoding = transfer_encoding.val
    ; rqh_host = host.val
    ; rq_all = lst
    }
;


(************************************************************)

(* True for space, LF, CR and horizontal tab: the chars that end a method
   or URI token on the request line. *)
value is_spaces c =
  c = '\x20' || c = '\x0A' || c = '\x0D' || c = '\x09'
;

(* True for LF and CR, the chars that may terminate a header line. *)
value is_line_term c =
  c = '\x0A' || c = '\x0D'
;

(* True for space and horizontal tab (linear whitespace within a line). *)
value is_whitespace c =
  c = '\x20' || c = '\x09'
;

(************************************************************)

(* A request is considered to carry a body when any of Transfer-Encoding,
   Content-Length or Transfer-Length was given. *)
value request_has_message_body rq =
  let h = rq.rq_headers in
  let no_body_markers =
    h.transfer_encoding = []
    && h.content_length = None
    && h.transfer_length = None
  in
  not no_body_markers
;

(* [default_uri_scheme] pre-wrapped in an option, ready to be stored in a
   URI's [scheme] field. *)
value some_default_uri_scheme = some default_uri_scheme
;

(************************************************************)



module Make
  (IO : Amall_types.IO_Type)
  (I : It_type.IT with type It_IO.m 'a = IO.m 'a)
=
struct

(* HTTP request parsing and response output, parameterised over an IO monad
   [IO] and an iteratee library [I] running in that same monad. *)

open I.Ops;


(* Adds the headers that have to be computed from other parts of the
   response.  Currently this is only "Content-length":
   - for [Body_string s] it is set to [String.length s];
   - for [File_contents] it is set to the declared file size;
   - for [No_body] it is removed.
   Header names are compared case-insensitively.  [lst] is the
   user-supplied header list, and the result is returned in the IO monad.
 *)

value response_headers resp lst =
  let nocase = String.eq_nocase_latin1 in
  let headers_with_lengths =
    let set_cl len_str =
        List.Assoc.replace
          ~keq:nocase
          "Content-length"
          len_str
          lst
    in
    match resp.rs_body with
    [ No_body ->
        IO.return & List.Assoc.remove ~keq:nocase "Content-length" lst
    | Body_string s ->
        s |> String.length |> string_of_int |> set_cl |> IO.return
    | File_contents _fn sz ->
        sz |> Int64.to_string |> set_cl |> IO.return
    ]
  in
    headers_with_lengths
;


(* Render one response header as "name: value\r\n".
   Raises [Invalid_argument] when the name is empty or contains CR, LF,
   ':', space or NUL, or when the value contains CR or LF. *)
value string_of_header (k, v) =
  let check_name c =
    if String.contains k c
    then invalid_arg
      "http response header: header name should not contain char %C"
      c
    else
      ()
  (* CR/LF in a value would end the header early and let the rest of the
     value be read as further headers or as the body (response splitting) *)
  and check_value c =
    if String.contains v c
    then invalid_arg
      "http response header: header value should not contain char %C"
      c
    else
      ()
  in
    ( if k = ""
      then invalid_arg "http response header: header name should not be empty"
      else ()
    ; check_name '\r'; check_name '\n'; check_name ':'; check_name '\x20'
    ; check_name '\x00'
    ; check_value '\r'; check_value '\n'
    ; sprintf "%s: %s\r\n" k v
    )
;


(* Render the status line and the header block of [rs], including the
   terminating empty line, after [response_headers] has filled in the
   computed headers.
   Fails in the IO monad with [Invalid_argument] when the status code does
   not have 3 digits or the reason phrase contains CR/LF.  [string_of_header]
   raises [Invalid_argument] for malformed header names or values. *)
value string_of_response_headers rs : IO.m string =
  let err msg = IO.error
    (Invalid_argument (sprintf "http response: %s" msg)) in
  let code = rs.rs_status_code in
  if code < 100 || code >= 1000
  then err & sprintf "status code must be 3-digit (now: %i)" code
  else
  let reas = rs.rs_reason_phrase in
  if String.contains reas '\n' || String.contains reas '\r'
  then err "reason phrase must not contain CR or LF"
  else
  response_headers rs rs.rs_headers.rs_all >>% fun processed_headers ->
  IO.return &
  sprintf "HTTP/1.1 %i %s\r\n%s\r\n"
    code
    reas
    (String.concat "" &
     List.map string_of_header &
     processed_headers
    )
;


(* Expect the literal [str] at the head of the stream.  [I.heads] reports
   how many leading chars match; anything short of the whole string throws
   [err]. *)
value read_the_string str err =
  let charlist = String.explode str in
  I.heads charlist >>= fun matched ->
  if matched = String.length str
  then I.return ()
  else I.throw_err err
;


(************************************************************)

(* Finish [it] with [I.run] and lift its result into an iteratee over a
   possibly different element type.  Errors from [I.run] propagate. *)
value (it_eof : I.iteratee 'el1 'a -> I.iteratee 'el2 'a) it =
  I.lift (I.run it)
;

(* Like [it_eof], but discards the result and swallows any error raised
   while running [it] (best effort). *)
value (it_eof_ignore : I.iteratee 'el1 'a -> I.iteratee 'el2 unit) it =
  I.lift
    ( IO.catch
        (fun () ->
           I.run it >>% fun a ->
           let () = ignore a in
           IO.return ()
        )
        (fun _ -> IO.return ())
    )
;

(* Sequence [func] over [lst] left to right, collecting the results of the
   produced iteratees in the original order. *)
value list_map_all func lst =
  inner [] lst
  where rec inner rev_acc lst =
    match lst with
    [ [] -> I.return & List.rev rev_acc
    | [h :: t] ->
        func h >>= fun fh -> inner [fh :: rev_acc] t
    ]
;

(************************************************************)

(* Parse an HTTP request head and hand the body to [process_request].

   Reads "METHOD SP request-URI SP HTTP/major.minor" and a line terminator
   (CRLF or bare LF).  It then reads header lines: at most [max_header_len]
   chars each and [max_headers_size] chars for the whole block.  Lines
   starting with whitespace are folded into the previous header, joined by
   one space.  The header block ends at an empty line.
   The request URI is completed with the "Host" authority (when the
   request line has none) and with [default_uri_scheme] (when it has no
   scheme).

   Any malformed or unsupported input throws [Bad_request reason].

   The result is an iteratee that yields an inner iteratee producing
   [(request, result_of_process_request)]:
   - without a body, or with "Connection: close" and no length, the inner
     iteratee is [process_request request] with no body input fed yet;
   - with "Content-Length: len", the outer iteratee first feeds [len]
     chars to it through [I.take].
   Transfer codings are not implemented and are rejected.
 *)
value (it_http :
  (request -> I.iteratee char 'a) ->
  I.iteratee char (I.iteratee char (request * 'a))
)
process_request =
  let fail r = I.throw_err & Bad_request r in
  (* read one char and require it to be [c] *)
  let read_the_char c =
    (I.catch
      (fun () -> I.mapI some I.head)
      (fun [ End_of_file -> I.return None
           | x -> I.throw_err x])
    )
    >>= fun
    [ None -> fail & sprintf "expected %C, got eof" c
    | Some c' when c = c' -> I.return ()
    | Some c' -> fail & sprintf "expected %C, got %C" c c'
    ]
  in
  (* read chars until [break_pred] holds, at most [limit] of them *)
  let read_component ~limit ~name ~break_pred =
    I.break_limit ~pred:break_pred ~limit >>= fun (status, s) ->
    match status with
    [ `Hit_limit -> fail & sprintf "bad %s (too big)" name
    | `Hit_eof -> fail & sprintf "eof reading %s" name
    | `Found -> I.return & I.Subarray.to_string s
    ]
  in
  (* consume "\r\n" or a bare "\n"; true when a terminator was found *)
  let read_line_terminators =
    ( I.heads ['\r'; '\n'] >>= fun n ->
      if n = 0
      then I.heads ['\n']
      else I.return n
    ) >>= fun n ->
    I.return (n <> 0)
  in
  let read_method =
    read_component
      ~name:"method"
      ~limit:(max_request_method_len + 1)
      ~break_pred:is_spaces
    >>= fun meth_txt ->
    match
      List.Assoc.get_opt
        ~keq:String.eq_nocase_latin1
        meth_txt
        request_method_texts
    with
    [ None -> fail "method not supported"
    | Some meth -> I.return meth
    ]
  and read_uri =
    read_component
      ~name:"URI"
      ~limit:max_uri_len
      ~break_pred:is_spaces
    >>= fun uri_txt ->
    match Uri.parse uri_txt with
    [ None -> fail "bad uri"
    | Some u -> I.return u
    ]
  and read_version =
    let read_uint name =
      let max_digits = 9 in
      let not_digit c = (c > '9' || c < '0') in
      (* ignore_zeroes >>= fun () -> *)
      read_component
        ~limit:(max_digits + 1)
        ~name
        ~break_pred:not_digit
      >>= fun uint_txt ->
      (* [uint_txt] holds digits only, so [int_of_string] fails only when
         it is empty (e.g. "HTTP/.1") or too large for [int]: that is a
         malformed request, not an internal error *)
      try I.return & int_of_string uint_txt
      with [ Failure _ -> fail & sprintf "bad %s" name ]
    in
    read_the_string "HTTP/" (Bad_request "expected \"HTTP/\" string")
    >>= fun () ->
    read_uint "http major version" >>= fun ver_maj ->
    read_the_char '.' >>= fun () ->
    read_uint "http minor version" >>= fun ver_min ->
    I.return (ver_maj, ver_min)
  and read_eol =
    read_line_terminators >>= fun t ->
    if t
    then I.return ()
    else fail "end-of-line not found"
  and read_headers =
    let rec read_headers acc =
      read_component ~limit:max_header_len ~name:"header"
        ~break_pred:is_line_term >>= fun header_line ->
      read_line_terminators >>= fun t ->
      match (t, header_line) with
      [ (True, "") -> I.return & List.rev acc
      | (True, _) ->
          let _ () = dbg "header_line = %S" header_line in
          (* process the header *)
          let (first_spaces, header_line, _last_spaces) =
            String.trim_count is_whitespace header_line in
          if first_spaces > 0
          then
            (* continuation line: fold into the previous header *)
            match acc with
            [ [] -> fail "first header starts with whitespace"
            | [last :: others] ->
                read_headers [(last ^ " " ^ header_line) :: others]
            ]
          else
            read_headers [header_line :: acc]
      | (False, _) -> fail "premature end of headers"
      ]
    in
      read_headers []
  in
    read_method >>= fun meth ->
    read_the_char '\x20' >>= fun () ->
    read_uri >>= fun request_uri__ ->
    read_the_char '\x20' >>= fun () ->
    read_version >>= fun version ->
    read_eol >>= fun () ->
    I.limit max_headers_size read_headers >>= fun rh_it ->
    match rh_it with
    [ I.IE_cont (Some e) _k -> I.throw_err e
    | I.IE_cont None _ ->
        (* the limit was reached before the empty line *)
        it_eof_ignore rh_it >>= fun () ->
        fail & sprintf "headers too large (max %i bytes allowed)"
          max_headers_size
    | I.IE_done headers -> I.return headers
    ] >>= fun header_lines ->
    list_map_all
      (  fun line ->
           let (hname, sep, hval_sp) =
             String.split_by_first ( (=) ':' ) line in
           if sep = "" then fail "header without ':'"
           else
             let hval = String.trim is_whitespace hval_sp in
             I.return (hname, hval)
      )
      header_lines
    >>= fun header_lines ->
    let rq_headers = make_headers header_lines in
    let rq_uri =
      match (request_uri__.authority, rq_headers.rqh_host) with
      [ (None, Some (host_kind, host, port_opt)) ->
          (* let () = dbg "it_http: host=%S port=%s" host
            (match port_opt with [None -> "-" | Some i -> string_of_int i]) in
          *)
           { (request_uri__) with
             authority = Some
               { host_kind = host_kind
               ; host = host
               ; port = port_opt
               ; userinfo = None
               }
           }
      | _ -> request_uri__
      ]
    in
    let rq_uri =
      match rq_uri.scheme with
      [ None -> { (rq_uri) with scheme = some_default_uri_scheme }
      | Some _ -> rq_uri
      ]
    in
    let () = dbg "it_http: rq_uri: %s" (Uri.dump_uri rq_uri) in
    I.return
    { rq_method = meth
    ; rq_request_uri__ = request_uri__
    ; rq_version = version
    ; rq_headers = rq_headers
    ; rq_uri = rq_uri
    }
    >>= fun request ->
    let ret = I.mapI
      (fun res ->
        let () = dbg "got result from body iteratee" in
        (request, res)
      )
    in
    let user_it = process_request request in
    if not & request_has_message_body request
    then
      let () = dbg "request has no body" in
      I.return (ret user_it)
    else
      let () = dbg "request has body" in
      (
        (* todo: chunked t.e. *)
        let h = request.rq_headers in
        if h.transfer_encoding <> []
        then
          fail "non-identity transfer encodings are not implemented"
        else
          match h.content_length with
          [ Some len when len < 0 ->
              (* a negative length must never reach [I.take] *)
              fail "bad Content-Length"
          | Some len -> I.return & `Content_length len
          | None ->
              if List.exists ((=) `Close) h.connection
              then
                I.return `Till_eof
              else
                fail "411 Length required"
          ]
      ) >>= fun bounds ->
      match bounds with
      [ `Till_eof ->
          let () = dbg "bounds: till eof" in
          I.return & ret user_it
      | `Content_length len ->
          let () = dbg "bounds: Content_length %i" len in
          I.take len user_it >>= fun it ->
          I.return & ret it
      ]
;


(* Read the whole remaining input as an
   application/x-www-form-urlencoded body.  The result is the list of
   (name, value) pairs in body order, with both names and values
   urldecoded.  Empty bindings (an empty body, "a=1&&b=2", a trailing '&')
   are skipped.  A binding without '=' gets the empty value. *)
value it_post_vars : I.iteratee char (list (string * string))
 =
  I.gather_to_string >>= fun s ->
  let vars =
    s
    >> String.split_exact ((=) '&')
    >> List.filter (fun binding -> binding <> "")
    >> List.map
         (fun binding ->
            let (var_key, eq_sign, var_val) =
              String.split_by_first ((=) '=') binding
            in
            let key = String.urldecode var_key in
              if eq_sign = ""
              then
                (* stopgap: treat a bare name as having an empty value *)
                (key, "")
              else
                (key, String.urldecode var_val)
         )
  in
    I.return vars
;


(* Body handler for [it_http]: parse the body as POST variables with
   [it_post_vars] and return them together with [request].
   NOTE(review): every variable, passwords included, is logged through
   [dbg]. *)
value (request_with_post_vars
: request -> I.iteratee char (request * list (string * string))
)
request
 =
  it_post_vars >>= fun vars ->
  let () =
    List.iter
      (fun (k, v) ->
         dbg "body: %s = %S" k v
      )
      vars
  in
  I.return (request, vars)
;



(***************************************************************)


(* Upper bound on the chunk size used to copy file bodies; mutable so that
   callers can tune it. *)
value output_file_buffer_size = ref 16384
;

(* Copy exactly [sz] bytes of file [fn] to [outch].  Fails with
   [Failure "file is shorter than expected"] when the file ends early.
   The input channel is closed on both success and failure. *)
value output_body_file outch fn sz =
  IO.open_in fn >>% fun inch ->
  let finally () = IO.close_in inch in
  IO.catch
    (fun () ->
       let bufsz = Int64.to_int &
         min (Int64.of_int output_file_buffer_size.val) sz in
       let buf = String.make bufsz '\x00' in
       read_loop sz
       where rec read_loop left =
         if left = 0L
         then
           IO.return ()
         else
           let to_read = Int64.to_int & min (Int64.of_int bufsz) left in
           IO.read_into inch buf 0 to_read >>% fun have_read ->
           if have_read = 0
           then IO.error (Failure "file is shorter than expected")
           else
             ( write_loop 0 have_read >>% fun () ->
               read_loop (Int64.sub left (Int64.of_int have_read))
             )
             (* [write_from] may write less than asked: loop until done *)
             where rec write_loop ofs to_write =
               let () = assert (to_write >= 0) in
               if to_write = 0
               then
                 IO.return ()
               else
                 IO.write_from outch buf ofs to_write >>% fun written ->
                 write_loop (ofs + written) (to_write - written)
    )
    (fun e ->
       finally () >>% fun () ->
       IO.error e
    )
  >>% fun () ->
  (* success path: close too, otherwise every served file leaks a
     descriptor *)
  finally ()
;


(* Write a response body to [outch] according to its kind. *)
value output_body outch rs_body =
  match rs_body with
  [ No_body -> IO.return ()
  | Body_string s -> IO.write outch s
  | File_contents fn sz ->
      output_body_file outch fn sz
  ]
;


(* Write [rs] to [outch]: status line and headers, then the body (skipped
   when [is_head]), then flush.  If the headers cannot be rendered, a 500
   response whose body is the exception text is written instead, through
   this same function. *)
value rec output_response ~is_head outch rs : IO.m unit =
  IO.catch
    (fun () ->
      string_of_response_headers rs >>% fun hstr -> IO.return & `Ok hstr
    )
    (fun e ->
        let msg = Printexc.to_string e in
        IO.return & `Error
          { rs_status_code = 500
          ; rs_reason_phrase = "Internal server error"
          ; rs_headers = { rs_all = [] }
          ; rs_body = Body_string msg
          }
    )
  >>% fun res_hstr ->
  match res_hstr with
  [ `Error rs -> output_response ~is_head outch rs
  | `Ok hstr ->
      IO.write outch hstr >>% fun () ->
      (if is_head
       then IO.return ()
       else output_body outch rs.rs_body
       ) >>% fun () ->
      IO.flush outch
  ]
;


end;


(* Iteratees over the pure IO monad [Pure_IO]; [get_basic_auth] uses them
   to base64-decode credentials without real IO. *)
module It_pure = Iteratees.Make(Pure_IO);

(* Extract credentials from an "Authorization: Basic <base64>" header.

   The header name is matched case-insensitively, since HTTP field names
   are case-insensitive; the "Basic" scheme is matched case-insensitively
   too.  Returns [Some (user, password)] when the value consists of exactly
   two whitespace-separated words, the second decodes as base64, and the
   decoded text contains a ':'.  The password is everything after the first
   ':'.  Returns [None] otherwise: missing header, another scheme, invalid
   base64, or no colon.
 *)
value get_basic_auth
 : rq_headers -> option (string * string)  (* Some (user, password) *)
 = fun h ->
     match
       List.Assoc.get_opt
         ~keq:String.eq_nocase_latin1
         "Authorization"
         h.rq_all
     with
     [ None -> None
     | Some hval ->
         match String.split is_whitespace hval with
         [ [meth; uue] when String.eq_nocase_latin1 meth "basic" ->
             let res_user_pass =
               It_pure.(
                 Pure_IO.catch
                   (fun () ->
                      Pure_IO.bind
                      (fun it -> run it)
                      (enum_string uue &
                       joinI &
                       base64_decode gather_to_string
                      )
                   )
                   Pure_IO.error
               ) in
             match res_user_pass with
             [ `Error _ -> None
             | `Ok user_pass ->
                 let (user, colon, pass) =
                   String.split_by_first ( (=) ':' ) user_pass in
                 if colon = ":"
                 then Some (user, pass)
                 else None
             ]
         | _ -> None
         ]
     ]
;


(* Parse a Content-Type header value with the [Urilex.content_type] lexer.
   The result is returned exactly as the lexer produces it. *)
value parse_content_type
 : string ->
   [= `Ok of (string * string * list (string * string)) | `Error of string ]
 = fun str ->
     let lexbuf = Lexing.from_string str in
     Urilex.content_type lexbuf
;