Source

screentime-monitor / rule_client.ml

Full commit
open Core.Std
open Async.Std
open Client_common

(** [show] subcommand: fetch the user's current rules from the server with
    [Protocol.get_rules] and print them, rendered by [With_format.format],
    followed by a newline. *)
let show_rules =
  let print_rules username conn =
    Rpc.Rpc.dispatch_exn Protocol.get_rules conn username
    >>| fun rules ->
    rules |> With_format.format |> printf "%s\n"
  in
  Command.async_basic
    ~summary:"Retrieve the screentime rules"
    (shared_flags ())
    (fun username () -> setup_conn username print_rules)

(** [edit] subcommand: fetch the user's current rules, open them in
    [$EDITOR] (falling back to ["emacs"]) via [edit_file], and, if the user
    answers "y"/"yes" to the confirmation prompt, upload the edited rules
    with [Protocol.set_rules].

    The rules are round-tripped through a temporary [.scm] file.  That file
    is removed once the editor session is over (whether it succeeded or
    raised) so repeated runs do not accumulate stale files in the temp
    directory. *)
let edit_rules =
  Command.async_basic
    ~summary:"Set the screentime rule"
    (shared_flags ())
    (fun username () -> setup_conn username (fun username conn ->
       Rpc.Rpc.dispatch_exn Protocol.get_rules conn username
       >>= fun current_rules ->
       let tempfile = Filename.temp_file "rule" ".scm" in
       Monitor.protect
         (fun () ->
            Writer.save tempfile ~contents:(With_format.format current_rules)
            >>= fun () ->
            let editor =
              match Sys.getenv "EDITOR" with None -> "emacs" | Some x -> x
            in
            edit_file (module Rule.List) ~editor ~tempfile)
         ~finally:(fun () ->
           (* Best-effort cleanup: a failed unlink (e.g. the file was
              already removed) must not mask the outcome of the edit. *)
           Deferred.ignore (try_with (fun () -> Unix.unlink tempfile)))
       >>= function
       | None -> return ()
       | Some rules ->
         printf "Set the rules for %s? (y/N):" (Username.to_string username);
         Reader.read_line (Lazy.force Reader.stdin)
         >>= fun response ->
         (* Only an explicit "y"/"yes" (case-insensitive, surrounding
            whitespace ignored) confirms; anything else, including EOF,
            declines. *)
         let upload =
           match response with
           | `Eof -> false
           | `Ok s ->
             (match s |> String.lowercase |> String.strip with
              | "y" | "yes" -> true
              | _ -> false)
         in
         if not upload then
           (printf "Not setting rules.\n"; return ())
         else
           Rpc.Rpc.dispatch_exn Protocol.set_rules conn (username,rules)
           >>| fun () ->
           printf "Rules set\n"
     ))

(** [violations] subcommand: fetch today's violations for the user with
    [Protocol.todays_violations] and print them as a human-readable
    s-expression of [(rule, span, status)] triples. *)
let rule_violations =
  let print_violations username conn =
    Rpc.Rpc.dispatch_exn Protocol.todays_violations conn username
    >>| fun violations ->
    let rendered =
      violations
      |> <:sexp_of<(Rule.t * Time.Span.t * Rule_store.Status.t) list>>
      |> Sexp.to_string_hum
    in
    printf "%s\n" rendered
  in
  Command.async_basic
    ~summary:"Get any rule violations"
    (shared_flags ())
    (fun username () -> setup_conn username print_violations)

(** [violation_pipe user conn] polls [Protocol.todays_violations] for [user]
    once per second and pushes every result into a fresh pipe, returning the
    reading end.  Polling stops when that reader is closed.  A result that
    arrives after the pipe has been closed is dropped instead of written. *)
let violation_pipe user conn =
  let reader, writer = Pipe.create () in
  let poll () =
    Rpc.Rpc.dispatch_exn Protocol.todays_violations conn user
    >>= fun violations ->
    match Pipe.is_closed writer with
    | true -> Deferred.unit
    | false -> Pipe.write writer violations
  in
  Clock.every' (sec 1.) ~stop:(Pipe.closed reader) poll;
  reader

(** A rule violation as displayed by the monitor, annotated with whether it
    should raise an alert ([reportable]). *)
module Full_violation = struct
  type t =
    { rule: Rule.t
    ; status: Rule_store.Status.t
    ; exceeded_by: Time.Span.t
    ; reportable: bool
    }

  (** Print [ts] to stdout as an ASCII table, one row per violation, with
      columns for the rule name, the overrun, the ack status and the
      [reportable] flag. *)
  let print_list ts =
    let module Ascii_table = Textutils.Std.Ascii_table in
    let cols = 
      Ascii_table.(
        [ Column.create "rule"  (fun t -> t.rule.name |> Rule.Name.to_string)
        ; Column.create "exceeded by" (fun t -> Time.Span.to_string t.exceeded_by)
        ; Column.create "status" (fun t ->
          match t.status with Acked -> "Ack" | Unacked -> "-")
        ; Column.create "reportable" (fun t -> Bool.to_string t.reportable)
        ]
      )
    in
    printf "%s"
      (Ascii_table.to_string ~display:Ascii_table.Display.line cols ts)

  (** Build a violation from a [(rule, exceeded_by, status)] triple; the
      result is never reportable. *)
  let create_unreportable (rule,exceeded_by,status) =
    { rule; exceeded_by; status; reportable = false }

  (** Diff two consecutive polls of [(rule, exceeded_by, status)] triples:
      - rules present only in [prev] are dropped;
      - rules present only in [curr] are kept but not reportable;
      - rules present in both are reportable iff they are [Unacked] and the
        overrun grew since [prev].
      Uses [Rule.Map.of_alist_exn], so it raises if either list contains
      the same rule twice. *)
  let create ~prev ~curr =
    let to_map vs =
      vs
      |> List.map ~f:(fun (rule,x,y) -> (rule,(x,y)))
      |> Rule.Map.of_alist_exn
    in
    Map.merge (to_map prev) (to_map curr)
      ~f:(fun ~key:rule diff ->
        match diff with
        | `Left _  -> None
        | `Right (exceeded_by,status) ->
          Some {rule;exceeded_by;status;reportable=false}
        | `Both ((prev_span,_),(curr_span,status)) ->
          let reportable = 
            match (status : Rule_store.Status.t) with
            | Acked ->  false
            | Unacked -> Time.Span.(curr_span > prev_span)
          in
          let exceeded_by = curr_span in
          Some {rule;status;exceeded_by;reportable}
      )
    |> Map.data

  (** [get_pipe user conn] returns a pipe of violation lists for [user]
      derived from [violation_pipe].  The first batch is all unreportable;
      each later batch is [create ~prev ~curr] against the previous poll.

      Closing the returned reader also closes the upstream poll pipe, which
      is what stops [violation_pipe]'s polling timer. *)
  let get_pipe user conn =
    let violations = violation_pipe user conn in
    let (r,w) = Pipe.create () in
    (* Propagate downstream closure upstream; otherwise the poller would
       keep running (and block on pushback) after the consumer is gone. *)
    upon (Pipe.closed w) (fun () -> Pipe.close_read violations);
    let rec loop prev =
      Pipe.read violations
      >>= function
      | `Eof -> Pipe.close w; Deferred.unit
      | `Ok curr ->
        (* The consumer may have closed [r] while we waited; writing to a
           closed pipe raises, so check on every batch, including the
           first. *)
        if Pipe.is_closed w then Deferred.unit
        else begin
          let batch =
            match prev with
            | None -> List.map ~f:create_unreportable curr
            | Some prev -> create ~prev ~curr
          in
          Pipe.write w batch
          >>= fun () -> loop (Some curr)
        end
    in
    don't_wait_for (loop None);
    r
end

(** [throttle limit ~if_not_run] returns a staged runner.  Applying it to [f]
    calls [f ()] and records the current time if strictly more than [limit]
    has elapsed since the last recorded run; otherwise it returns
    [if_not_run ()].  The clock starts at [Time.epoch], so with any
    realistic [limit] the first call runs [f]. *)
let throttle limit ~if_not_run =
  let last_run = ref Time.epoch in
  let run_if_due f =
    let now = Time.now () in
    let elapsed = Time.diff now !last_run in
    match Time.Span.(elapsed > limit) with
    | false -> if_not_run ()
    | true ->
      last_run := now;
      f ()
  in
  stage run_if_due


(** Watch [user]'s violations until [stop] becomes determined.

    For every update from [Full_violation.get_pipe], the screen is cleared
    and a spinner frame plus the violation table are printed.  For each
    reportable violation a notification is spawned through [Notify.spawn],
    throttled so that at most one alert fires per 30 seconds.  The
    notification's [~execute] command re-invokes this program as
    [<path> rules ack <rule>], where [<path>] is [Sys.argv.(0)] resolved
    through [which]. *)
let monitor_violations user conn ~stop =
  let r = Full_violation.get_pipe user conn in
  let alert_throttle =
    unstage (throttle (sec 30.) ~if_not_run:(fun () -> Deferred.unit))
  in
  let spinner = Lazy.force spinner in
  upon stop (fun () -> Pipe.close_read r);
  Async_shell.run_full "which" [Sys.argv.(0)]
  >>= fun my_path ->
  let my_path = String.strip my_path in
  Pipe.iter r ~f:(fun violations ->
    clear_screen ()
    >>= fun () ->
    print_endline (spinner ());
    Full_violation.print_list violations;
    Deferred.List.iter violations
      ~f:(fun {Full_violation. rule;exceeded_by;status=_;reportable} ->
        if not reportable then Deferred.unit
        else
          alert_throttle (fun () ->
            printf ".\n";
            (* Command to execute when popup is clicked.  The string is
               presumably interpreted by a shell, so both the program path
               and the user-defined rule name are shell-quoted; a name
               containing a quote or space would otherwise break (or inject
               into) the command. *)
            let execute =
              sprintf "%s rules ack %s"
                (Filename.quote my_path)
                (Filename.quote (Rule.Name.to_string rule.Rule.name))
            in
            print_endline execute;
            Notify.spawn
              ~title:"Screentime exceeded"
              ~sound:`Default
              ~execute
              (sprintf "'%s' exceeded by %s"
                 (Rule.Name.to_string rule.Rule.name)
                 (Time.Span.to_string exceeded_by))
          )))

(** [monitor] subcommand: repeatedly connect and run the violation monitor
    until the termination signal ([on_term_signal]) fires.  Attempts are
    driven by [retry]; each failure prints the exception as a sexp and each
    retry prints a notice, both prefixed by [clear_string]. *)
let monitor_violations =
  let run user () =
    force clear_string
    >>= fun cstring ->
    let on_retry () = printf "%sretrying..\n" cstring in
    let on_error err =
      let details = Sexp.to_string_hum (Exn.sexp_of_t err) in
      printf "%sFailed!  Will retry.\n\n%s\n" cstring details
    in
    let attempt () =
      try_with (fun () ->
        setup_conn user (fun user conn ->
          monitor_violations user conn ~stop:(force on_term_signal)))
    in
    Deferred.ignore (retry ~on_retry ~on_error attempt)
  in
  Command.async_basic
    ~summary:"Monitor for violations"
    (shared_flags ())
    run

(** [ack RULE] subcommand: mark [RULE] as acknowledged for the user via
    [Protocol.acknowledge].  Violations whose status is [Acked] are never
    reportable in [Full_violation.create], so (presumably once the server
    reports the rule as acked) the monitor stops alerting for it. *)
let ack =
  Command.async_basic
    ~summary:"Acknowledge a rule, and prevent further warnings"
    Command.Spec.(
      shared_flags ()
      +> anon ("rule" %: Arg_type.create Rule.Name.of_string)
    )
    (fun user rule () -> setup_conn user (fun user conn ->
       Rpc.Rpc.dispatch_exn Protocol.acknowledge conn (user,rule)
     ))

(** [unack RULE] subcommand: clear the acknowledgement of [RULE] for the user
    via [Protocol.unacknowledge].  [Unacked] violations can become reportable
    again in [Full_violation.create], so this re-enables alerts for the rule.
    (The help summary used to claim the opposite, copied from [ack].) *)
let unack =
  Command.async_basic
    ~summary:"Unacknowledge a rule, re-enabling further warnings"
    Command.Spec.(
      shared_flags ()
      +> anon ("rule" %: Arg_type.create Rule.Name.of_string)
    )
    (fun user rule () -> setup_conn user (fun user conn ->
       Rpc.Rpc.dispatch_exn Protocol.unacknowledge conn (user,rule)
     ))

(** Command group bundling the rule subcommands: show, edit, violations,
    monitor, ack and unack. *)
let command =
  let subcommands =
    [ "show", show_rules
    ; "edit", edit_rules
    ; "violations", rule_violations
    ; "monitor", monitor_violations
    ; "ack", ack
    ; "unack", unack
    ]
  in
  Command.group
    ~summary:"Tools for interacting with screentime rules"
    subcommands