Commits

cacol89 committed 1846291

enabled having clauses

Comments (0)

Files changed (4)

     "FROM "^(String.concat ", " aliases)
 
 (** Given an aggregate function name, checks if it is supported by YADI and
- * returns it i upper case*)
+ * returns it*)
 let check_agg_function fn =
     let allowed = ["MAX";"MIN";"SUM";"AVG";"COUNT"] in
-    let upc = String.uppercase fn in
-    if List.mem upc allowed then upc
+    if List.mem fn allowed then fn
     else raise (Yadi_error (
-        "Aggregate function '"^upc^"' is not supported, "^
+        "Aggregate function '"^fn^"' is not supported, "^
         "allowed functions are: "^(String.concat ", " allowed)
     ))
 
+(*Given a variable name, returns the name of a EDB/IDB column
+ * that defines it, or if it is equal to a constant, the
+ * value of the constant.*)
+let vname_to_col (vt:vartab) (eqt:eqtab) key vname =
+    (*If the variable appears in a positive rterm, the value
+     * is the name of the respective rterm's table column*)
+    if Hashtbl.mem vt vname
+        then List.hd (Hashtbl.find vt vname)
+    (*If the variable does not appear in a positive rterm, but
+     * it does in an equality value, then the value is the eq's
+     * constant, the var has to be removed from the eqtab*)
+    else if Hashtbl.mem eqt vname
+        then string_of_const (eqt_extract eqt vname)
+    (*Else, the query is unsafe or inclomplete*)
+    else raise (Yadi_error (
+            "Predicate "^(string_of_symtkey key)^
+            " is unsafe, variable "^vname^" is not in a positive "^
+            "goal or strict equality relation."
+        )
+    )
 
 (** Given the head of the rule, the vartab, and te eqtab, returns the code that
  * must be in the select clause. All columns are aliased as col0, col1, ...*)
 let get_select_clause (vt:vartab) (eqt:eqtab) rterm =
     let vlst = get_rterm_varlist rterm in 
+    let key = symtkey_of_rterm rterm in
     if vlst = [] then
         raise (Yadi_error
             ("Predicate "^(get_rterm_predname rterm)^
             " has arity 0, which is not allowed"))
     else
-    let vname_to_col vname =
-        (*If the variable appears in a positive rterm, the value
-         * is the name of the respective rterm's table column*)
-        if Hashtbl.mem vt vname
-            then List.hd (Hashtbl.find vt vname)
-        (*If the variable does not appear in a positive rterm, but
-         * it does in an equality value, then the value is the eq's
-         * constant, the var has to be removed from the eqtab*)
-        else if Hashtbl.mem eqt vname
-            then string_of_const (eqt_extract eqt vname)
-        (*Else, the query is unsafe or inclomplete*)
-        else raise (Yadi_error (
-                "Predicate "^(string_of_symtkey (symtkey_of_rterm rterm))^
-                " is unsafe, variable "^vname^" is not in a positive "^
-                "goal or strict equality relation."
-            )
-        ) in
+    (*Transform variables to column names. Treat namedVars and
+     * aggregates differently*)
     let var_value v = match v with
         NamedVar _ | NumberedVar _ ->
-            vname_to_col (string_of_var v)
+            vname_to_col vt eqt key (string_of_var v)
         | AggVar (fn,vn) ->
-            (check_agg_function fn)^"("^(vname_to_col vn)^")"
+            (check_agg_function fn)^"("^(vname_to_col vt eqt key vn)^")"
         | _ -> invalid_arg ("not-expected vartype in head of predicate"^
-            (string_of_symtkey (symtkey_of_rterm rterm)))
+            (string_of_symtkey key))
     in
     let cols = List.map var_value vlst in
+    (*Create aliases*)
     let rec alias ind = function
         | [] -> ""
         | [col] -> col^" AS col"^(string_of_int ind)
     let feqt = Hashtbl.fold eq_const eqt [] in
     (*Transform the inequalities in the list for strings of the form
      * "CName op value" *)
+    let ineq_tuples = List.map extract_ineq_tuple ineq in
     let ineq_const (op,var,value) acc =
         let vname = string_of_var var in
         let cname = List.hd (Hashtbl.find vt vname) in
         (cname^" "^op^" "^(string_of_const value))::acc in
-    let fineq = List.fold_right ineq_const ineq [] in
+    let fineq = List.fold_right ineq_const ineq_tuples [] in
     (*Transform the negated rterms into SQL*)
     let fnrt = sql_of_negated_rterms idb vt cnt eqt neg_rt in
     (*merge all constraints*)
 
 (** Generates the SQL that correspond to aggregation in a rule,
  * this corresponds to GROUP BY and HAVING clauses.
+ * 
  * The GROUP BY clause will be comprised of the columns in the
  * resulting table that are not aggregates (obviously), nor constants.
- * 
+ * The HAVING clause will correspond to comparissons with aggregates.
+ *
  * If the predicate's head do not contain aggregation functions, nothing is
  * returned. If this condition is met but aggregate functions appear on the
  * rule's body, an error is raised.
  * 
+ * If there are comparissons with respect to aggregates that do not appear
+ * in the rule head, an error is also raised.
+ *
  * PRECONDITION: it is assumed that NumberedVars in the rule's head correspond
  * to constants.
  * *)
-let get_aggregation_sql (cnt:colnamtab) rule =
-    let head = rule_head rule in
+let get_aggregation_sql (vt:vartab) (cnt:colnamtab) head agg_eqs agg_ineqs =
     let vars = get_rterm_varlist head in
+    let key = symtkey_of_rterm head in
+    (*Merge the equalities and inequalities in a simple list*)
+    let eq_t = List.map extract_eq_tuple agg_eqs in
+    let aug_eq_t = List.map (fun (x,y) -> ("=",x,y)) eq_t in
+    let ieq_t = List.map extract_ineq_tuple agg_ineqs in
+    let comparisons = aug_eq_t@ieq_t in
     (*Check if the rule has aggregation*)
     let is_agg = List.exists is_aggvar vars in
-    if not is_agg then "" else
-    let key = symtkey_of_rule rule in
+    if not is_agg then
+        if comparisons = [] then ""
+        else raise (Yadi_error (
+            "Predicate "^(string_of_symtkey key)^
+            " contains comparisons of aggregates but defines no "^
+            "aggregations in its head"))
+    else
     let cols = Hashtbl.find cnt key in
     (*Calculate the GROUP BY clause*)
     let group_var acc col = function
     let group_by_sql =
         if grp_cols = [] then ""
         else ("GROUP BY "^(String.concat ", " grp_cols)) in
-    group_by_sql
+    (*Calculate the HAVING clause*)
+    (*Extract the available aggregations in the head, and place them
+     * in a list, which values will be the function applied to a column-name*)
+    let av_aggs = Hashtbl.create 100 in
+    let fake_eqt:eqtab = Hashtbl.create 100 in
+    let insert_agg = function
+        | AggVar (fn,vn) ->
+            let col = vname_to_col vt fake_eqt key vn in
+            Hashtbl.add av_aggs (fn,vn) (fn^"("^col^")")
+        | _ -> () in
+    List.iter insert_agg vars;
+    (*Build the contraints and check for unavailable aggregates*)
+    let agg_var_col agv =
+        let tuple = extract_aggvar_tuple agv in
+        if Hashtbl.mem av_aggs tuple then Hashtbl.find av_aggs tuple
+        else raise (Yadi_error (
+            "Predicate "^(string_of_symtkey key)^" contains comparisons of "^
+            "aggregates that are not defined in its head"
+        )) in
+    let comp_const (op,var,const) =
+        (agg_var_col var)^" "^op^" "^(string_of_const const) in 
+    let comp_sql = List.map comp_const comparisons in
+    let having_sql = if comp_sql = [] then "" else
+        "HAVING "^(String.concat " AND " comp_sql) in
+    group_by_sql^" "^having_sql
 
 (** Takes a list of terms and splits them in positive rterms,
  * negative terms, equalities, and inequalities*)
     let rec split t (pos,neg,eq,inq) = match t with
         | Rel rt -> (rt::pos,neg,eq,inq)
         | Not rt -> (pos,rt::neg,eq,inq)
-        | Equal (x,y) -> (pos,neg,(x,y)::eq,inq) 
-        | Ineq (op,x,y) -> (pos,neg,eq,(op,x,y)::inq) in
+        | Equal _ -> (pos,neg,t::eq,inq) 
+        | Ineq _ -> (pos,neg,eq,t::inq) in
     List.fold_right split terms ([],[],[],[])
 
 (** Takes a rule and makes a SQL query that calculates its result*)
 let sql_of_rule (idb:symtable) (cnt:colnamtab) rule =
     let head = rule_head rule in
     let body = rule_body rule in
-    (*Extract positive rterms from the rule*)
-    let (p_rt,n_rt,equalities,ineq) = split_terms body in
+    (*Split terms in the rule's body. Separate equalities
+     * and inequalities in variable and aggregates relations.*)
+    let (p_rt,n_rt,all_eqs,all_ineqs) = split_terms body in
+    let (agg_eqs,eqs) = List.partition is_agg_equality all_eqs in
+    let (agg_ineqs,ineqs) = List.partition is_agg_inequality all_ineqs in
     (*Build vartab, and eqtab for select and where clauses*)
     let vt = build_vartab cnt p_rt in
-    let eqt = build_eqtab equalities in
-    let select_sql = get_select_clause vt eqt head in
+    let eqtb = build_eqtab eqs in
+    let select_sql = get_select_clause vt eqtb head in
     let from_sql = get_from_clause idb p_rt in
-    let where_sql = get_where_clause idb vt cnt eqt ineq n_rt in
-    let agg_sql = get_aggregation_sql cnt rule in
+    let where_sql = get_where_clause idb vt cnt eqtb ineqs n_rt in
+    let agg_sql = get_aggregation_sql vt cnt head agg_eqs agg_ineqs in
     String.concat " " [select_sql;from_sql;where_sql;agg_sql]
 
 (**Takes a list of similar rules (same head) and generates the SQL statement
         List.fold_left extract_rterm [] t
     | Query _    -> invalid_arg "function get_all_rule_rterms called with a query"
 
+let extract_eq_tuple = function
+    | Equal (v,c) -> (v,c)
+    | _ -> invalid_arg "function extract_eq_tuple called without an equality"
+
+let extract_ineq_tuple = function
+    | Ineq (s,v,c) -> (s,v,c)
+    | _ -> invalid_arg "function extract_ineq_tuple called without an inequality"
+
+let extract_aggvar_tuple = function
+    | AggVar (fn,vn) -> (fn,vn)
+    | _ -> invalid_arg "function extract_aggvar_tuple called without an aggregated var"
+
 (****************************************************
  *
  *  AST check / transformation functions
     | AggVar _ -> true
     | _ -> false
 
+let is_agg_equality = function
+    | Equal (AggVar _ , _ ) -> true
+    | Equal _ -> false
+    | _ -> invalid_arg "function is_agg_equality called without an equality"
+
+let is_agg_inequality = function
+    | Ineq (_ , AggVar _ , _) -> true
+    | Ineq _ -> false
+    | _ -> invalid_arg "function is_agg_inequality called without an equality"
+
 (****************************************************
  *
  *  String operations
   ;
 
   equation:	
-  | VARNAME EQ constant						{ Equal (NamedVar $1, $3) }
-  | VARNAME NE constant						{ Ineq ("<>", NamedVar $1, $3) }
-  | VARNAME LT constant						{ Ineq ("<", NamedVar $1, $3) }
-  | VARNAME GT constant						{ Ineq (">", NamedVar $1, $3) }
-  | VARNAME LE constant						{ Ineq ("<=", NamedVar $1, $3) }
-  | VARNAME GE constant						{ Ineq (">=", NamedVar $1, $3) }
+  | var_or_agg EQ constant	{ Equal ($1, $3) }
+  | var_or_agg NE constant	{ Ineq ("<>", $1, $3) }
+  | var_or_agg LT constant	{ Ineq ( "<", $1, $3) }
+  | var_or_agg GT constant	{ Ineq ( ">", $1, $3) }
+  | var_or_agg LE constant	{ Ineq ("<=", $1, $3) }
+  | var_or_agg GE constant	{ Ineq (">=", $1, $3) }
+  ;
+
+  var_or_agg:
+  | VARNAME     { NamedVar $1 }
+  | aggregate   { $1 }
   ;
 
   constant:
   | VARNAME     { NamedVar $1 }
   | ANONVAR     { AnonVar }
   | constant    { ConstVar $1 }
-  | agregate    { $1 }
+  | aggregate    { $1 }
   ;
 
-  agregate:
-  | VARNAME LPAREN VARNAME RPAREN       { AggVar ($1,$3) }
+  aggregate:
+  | VARNAME LPAREN VARNAME RPAREN       { AggVar (String.uppercase $1,$3) }
   ;

src/yadi_utils.ml

  * must be satisfied by the variables*) 
 type eqtab = (string,const) Hashtbl.t
 
-(** Given a list of (var,const) tuples, returns an eqtab with
- * the equality relations as var = value*)
-let build_eqtab tuples =
+(** Given a list of equality ASTs, returns an eqtab with
+ * the equality relations as var = value.
+ * PRECONDITION: There should not be aggregate equalities
+ * in the provided list.*)
+let build_eqtab eqs =
+    let tuples = List.map extract_eq_tuple eqs in
     let hs:eqtab = Hashtbl.create 100 in
     let add_rel (var,c) = match var with
         NamedVar _ | NumberedVar _ -> Hashtbl.add hs (string_of_var var) c