(* Source: screentime-monitor / rule_client.ml *)

open Core.Std
open Async.Std
open Client_common

(* "show" subcommand: fetch the user's rules from the server and print them
   using their [With_format] rendering, followed by a newline. *)
let show_rules =
  let print_rules username conn =
    Rpc.Rpc.dispatch_exn Protocol.get_rules conn username
    >>| fun rules ->
    rules |> With_format.format |> printf "%s\n"
  in
  Command.async_basic
    ~summary:"Retrieve the screentime rules"
    (shared_flags ())
    (fun username () -> setup_conn username print_rules)

(* "edit" subcommand: fetch the user's rules, write them to a temporary .scm
   file, open it in $EDITOR (falling back to "emacs") via [edit_file], and,
   after a y/N confirmation, upload the edited rules.  Any answer other than
   "y"/"yes" (case-insensitive, surrounding whitespace ignored), including
   EOF on stdin, declines the upload.  The temporary file is removed on every
   exit path. *)
let edit_rules =
  Command.async_basic
    ~summary:"Set the screentime rule"
    (shared_flags ())
    (fun username () -> setup_conn username (fun username conn ->
         Rpc.Rpc.dispatch_exn Protocol.get_rules conn username
         >>= fun rule ->
         let tempfile = Filename.temp_file "rule" ".scm" in
         (* [Filename.temp_file] creates the file, so remove it however we
            leave (upload, decline, abort or exception).  A failed unlink is
            deliberately ignored so it cannot mask the real outcome. *)
         Monitor.protect
           ~finally:(fun () ->
               Monitor.try_with (fun () -> Unix.unlink tempfile) >>| ignore)
           (fun () ->
              Writer.save tempfile ~contents:(With_format.format rule)
              >>= fun () ->
              let editor =
                match Sys.getenv "EDITOR" with None -> "emacs" | Some x -> x
              in
              (* A [None] result means there is nothing to upload. *)
              edit_file (module Rule.List) ~editor ~tempfile
              >>= function
              | None -> return ()
              | Some rules ->
                printf "Set the rules for %s? (y/N):" (Username.to_string username);
                Reader.read_line (Lazy.force Reader.stdin)
                >>= fun response ->
                let upload =
                  match response with
                  | `Eof -> false
                  | `Ok s ->
                    match s |> String.lowercase |> String.strip with
                    | "y" | "yes" -> true
                    | _ -> false
                in
                if not upload then
                  (printf "Not setting rules.\n"; return ())
                else
                  Rpc.Rpc.dispatch_exn Protocol.set_rules conn (username,rules)
                  >>| fun () ->
                  printf "Rules set\n")
       ))

(* "violations" subcommand: fetch today's violations for the user and print
   them as a human-readable s-expression of
   (rule, exceeded-by span, status) triples. *)
let rule_violations =
  let print_violations username conn =
    Rpc.Rpc.dispatch_exn Protocol.todays_violations conn username
    >>| fun violations ->
    let sexp =
      <:sexp_of<(Rule.t * Time.Span.t * Rule_store.Status.t) list>> violations
    in
    printf "%s\n" (Sexp.to_string_hum sexp)
  in
  Command.async_basic
    ~summary:"Get any rule violations"
    (shared_flags ())
    (fun username () -> setup_conn username print_violations)

(* Poll the server every second for [user]'s violations today and push each
   snapshot into a fresh pipe, whose reader is returned.  Polling stops once
   that reader is closed.  [Clock.every'] waits for each poll (including the
   pushback of the write) before scheduling the next one. *)
let violation_pipe user conn =
  let reader, writer = Pipe.create () in
  let poll () =
    Rpc.Rpc.dispatch_exn Protocol.todays_violations conn user
    >>= fun snapshot ->
    (* The reader may have been closed while the RPC was in flight. *)
    if Pipe.is_closed writer then Deferred.unit
    else Pipe.write writer snapshot
  in
  Clock.every' (sec 1.) ~stop:(Pipe.closed reader) poll;
  reader

(* A rule violation as displayed by the monitor, annotated with whether it
   should trigger an alert on the current tick. *)
module Full_violation = struct
  type t =
    { rule: Rule.t
    ; status: Rule_store.Status.t
    ; exceeded_by: Time.Span.t
    ; reportable: bool   (* whether this tick should raise an alert *)
    }

  (* Print [ts] to stdout as an ASCII table, one row per violation. *)
  let print_list ts =
    let module Ascii_table = Textutils.Std.Ascii_table in
    let cols =
      Ascii_table.(
        [ Column.create "rule"  (fun t -> t.rule.name |> Rule.Name.to_string)
        ; Column.create "exceeded by" (fun t -> Time.Span.to_string t.exceeded_by)
        ; Column.create "status" (fun t ->
            match t.status with Acked -> "Ack" | Unacked -> "-")
        ; Column.create "reportable" (fun t -> Bool.to_string t.reportable)
        ]
      )
    in
    printf "%s"
      (Ascii_table.to_string ~display:Ascii_table.Display.line cols ts)

  (* Build a non-reportable violation from a server-side
     (rule, exceeded_by, status) triple.  Used for the first snapshot, where
     there is nothing earlier to compare against. *)
  let create_unreportable (rule,exceeded_by,status) =
    { rule; exceeded_by; status; reportable = false }

  (* Diff two consecutive snapshots of (rule, exceeded_by, status) triples:
     - rules only in [prev] are dropped;
     - rules first appearing in [curr] are included but not reportable;
     - rules in both are reportable iff unacknowledged in [curr] and their
       overage grew since [prev].
     Result is ordered by [Rule.Map]'s key order.  Raises (via
     [Rule.Map.of_alist_exn]) if a snapshot lists the same rule twice. *)
  let create ~prev ~curr =
    let to_map vs =
      vs
      |> List.map ~f:(fun (rule,x,y) -> (rule,(x,y)))
      |> Rule.Map.of_alist_exn
    in
    Map.merge (to_map prev) (to_map curr)
      ~f:(fun ~key:rule diff ->
          match diff with
          | `Left _  -> None
          | `Right (exceeded_by,status) ->
            Some {rule;exceeded_by;status;reportable=false}
          | `Both ((prev_span,_),(curr_span,status)) ->
            let reportable =
              match (status : Rule_store.Status.t) with
              | Acked ->  false
              | Unacked -> Time.Span.(curr_span > prev_span)
            in
            let exceeded_by = curr_span in
            Some {rule;status;exceeded_by;reportable}
        )
    |> Map.data

  (* Return a pipe of annotated violation lists, one per poll of
     [violation_pipe].  The first list is entirely non-reportable; later ones
     are diffs against the previous snapshot (see [create]).  Closing the
     returned reader also shuts down the upstream poller. *)
  let get_pipe user conn =
    let violations = violation_pipe user conn in
    let (r,w) = Pipe.create () in
    (* Propagate downstream closure upstream; otherwise the poller in
       [violation_pipe] never sees its stop condition and is left blocked on
       pushback forever. *)
    upon (Pipe.closed w) (fun () -> Pipe.close_read violations);
    (* [prev] is [None] until the first snapshot has been forwarded. *)
    let rec loop prev =
      Pipe.read violations
      >>= function
      | `Eof -> Pipe.close w; Deferred.unit
      | `Ok curr ->
        (* The consumer may have closed [r] while we waited; [Pipe.write]
           raises on a closed pipe, so check on every iteration, including
           the first. *)
        if Pipe.is_closed w then Deferred.unit
        else begin
          let batch =
            match prev with
            | None -> List.map ~f:create_unreportable curr
            | Some prev -> create ~prev ~curr
          in
          Pipe.write w batch
          >>= fun () -> loop (Some curr)
        end
    in
    don't_wait_for (loop None);
    r
end

(* [throttle limit ~if_not_run] returns a staged runner.  Applying the runner
   to [f] calls [f ()] if strictly more than [limit] has elapsed since the
   last time a call actually ran, and records the current time.  Otherwise it
   calls [if_not_run ()] instead.  The last-run time starts at [Time.epoch],
   so in practice the first call runs. *)
let throttle limit ~if_not_run =
  let last_run = ref Time.epoch in
  let run_or_skip f =
    let now = Time.now () in
    let elapsed = Time.diff now !last_run in
    if Time.Span.(elapsed <= limit) then if_not_run ()
    else begin
      last_run := now;
      f ()
    end
  in
  stage run_or_skip


(* Display [user]'s violations, redrawing the table on each update, until
   [stop] becomes determined.  For each reportable violation a desktop
   notification is spawned, at most one every 30 seconds overall.  Clicking
   the notification acknowledges the rule by re-invoking this executable with
   "rules ack <name>". *)
let monitor_violations user conn ~stop =
  let r = Full_violation.get_pipe user conn in
  let alert_throttle =
    unstage (throttle (sec 30.) ~if_not_run:(fun () -> Deferred.unit))
  in
  upon stop (fun () -> Pipe.close_read r);
  Pipe.iter r ~f:(fun violations ->
      clear_screen ()
      >>= fun () ->
      Full_violation.print_list violations;
      Deferred.List.iter violations
        ~f:(fun {Full_violation. rule;exceeded_by;status=_;reportable} ->
            if not reportable then Deferred.unit
            else
              alert_throttle (fun () ->
                  printf ".\n";
                  let rule_name = Rule.Name.to_string rule.Rule.name in
                  (* Command to execute when the popup is clicked.  It is
                     evidently meant for a shell (the name used to be wrapped
                     in bare single quotes), so quote both the executable
                     path and the rule name properly.  Bare quoting broke on
                     names containing an apostrophe. *)
                  let execute =
                    sprintf "%s rules ack %s"
                      (Filename.quote (Filename.realpath Sys.argv.(0)))
                      (Filename.quote rule_name)
                  in
                  print_endline execute;
                  Notify.spawn
                    ~title:"Screentime exceeded"
                    ~sound:`Default
                    ~execute
                    (sprintf "'%s' exceeded by %s"
                       rule_name
                       (Time.Span.to_string exceeded_by))
                )))

(* "monitor" subcommand: run the violation monitor (the function defined
   above, which this binding shadows) until a termination signal arrives. *)
let monitor_violations =
  let run user conn =
    monitor_violations user conn ~stop:(force on_term_signal)
  in
  Command.async_basic
    ~summary:"Monitor for violations"
    (shared_flags ())
    (fun user () -> setup_conn user run)

(* "ack" subcommand: mark the named rule as acknowledged for the user via the
   [Protocol.acknowledge] RPC.  [Full_violation.create] never marks Acked
   violations as reportable. *)
let ack =
  Command.async_basic
    ~summary:"Acknowledge a rule, and prevent further warnings"
    Command.Spec.(
      shared_flags ()
      +> anon ("rule" %: Arg_type.create Rule.Name.of_string)
    )
    (fun user rule () -> setup_conn user (fun user conn ->
         Rpc.Rpc.dispatch_exn Protocol.acknowledge conn (user,rule)
       ))

(* "unack" subcommand: clear the acknowledgement on the named rule for the
   user via the [Protocol.unacknowledge] RPC.  This makes its violations
   eligible for warnings again.  (The help text previously duplicated ack's
   "prevent further warnings", which is the opposite of what this does.) *)
let unack =
  Command.async_basic
    ~summary:"Unacknowledge a rule, re-enabling further warnings"
    Command.Spec.(
      shared_flags ()
      +> anon ("rule" %: Arg_type.create Rule.Name.of_string)
    )
    (fun user rule () -> setup_conn user (fun user conn ->
         Rpc.Rpc.dispatch_exn Protocol.unacknowledge conn (user,rule)
       ))

(* Command group bundling every rule-related subcommand of this client. *)
let command =
  let subcommands =
    [ "show", show_rules
    ; "edit", edit_rules
    ; "violations", rule_violations
    ; "monitor", monitor_violations
    ; "ack", ack
    ; "unack", unack
    ]
  in
  Command.group
    ~summary:"Tools for interacting with screentime rules"
    subcommands