Source

z3 / src / ast / rewriter / factor_rewriter.cpp

Full commit
/*++
Copyright (c) 2011 Microsoft Corporation

Module Name:

    factor_rewriter.cpp

Abstract:

    Rewriting utilities for factoring polynomials in equations
    and inequalities.

Author:

    Nikolaj (nbjorner) 2011-19-05

Notes:

--*/

#include"factor_rewriter.h"
#include"ast_pp.h"
#include"rewriter_def.h"

// Construct a factor rewriter on top of the AST manager `m`.
// The arithmetic helper and the factor vector are initialized from the
// same manager.
factor_rewriter::factor_rewriter(ast_manager & m):
    m_manager(m),
    m_arith(m),
    m_factors(m)
{
}

// Dispatch the application f(args) to the matching factoring routine.
//
// Equalities are handled by mk_eq. Declarations whose family id matches the
// arithmetic helper's family id are routed by kind: <=, >=, < and > go to
// mk_le, mk_ge, mk_lt and mk_gt respectively. Anything else is left alone
// (BR_FAILED). All handled operators are expected to be binary.
br_status factor_rewriter::mk_app_core(
    func_decl * f, unsigned num_args, expr * const * args, expr_ref & result) {

    if (m().is_eq(f)) {
        SASSERT(num_args == 2);
        return mk_eq(args[0], args[1], result);
    }

    // Only arithmetic declarations are of interest beyond this point.
    if (f->get_family_id() != a().get_family_id()) {
        return BR_FAILED;
    }

    switch (f->get_decl_kind()) {
    case OP_LE:
        SASSERT(num_args == 2);
        return mk_le(args[0], args[1], result);
    case OP_GE:
        SASSERT(num_args == 2);
        return mk_ge(args[0], args[1], result);
    case OP_LT:
        SASSERT(num_args == 2);
        return mk_lt(args[0], args[1], result);
    case OP_GT:
        SASSERT(num_args == 2);
        return mk_gt(args[0], args[1], result);
    default:
        break;
    }
    return BR_FAILED;
}

// Rewrite arg1 = arg2 into a disjunction "some factor is zero".
//
// Fails (BR_FAILED) when arg1 is neither real nor integer, or when
// arg1 - arg2 could not be factored. When no monomials remain after
// mk_adds/mk_muls the equation is rewritten to true. Otherwise result is
//     f_1 = 0 \/ ... \/ f_n = 0
// with one disjunct per distinct factor recorded in m_powers.
br_status factor_rewriter::mk_eq(expr * arg1, expr * arg2, expr_ref & result) {
    if (!(a().is_real(arg1) || m_arith.is_int(arg1))) {
        return BR_FAILED;
    }

    mk_adds(arg1, arg2);
    mk_muls();

    if (m_muls.empty()) {
        result = m().mk_true();
        return BR_DONE;
    }

    if (!extract_factors()) {
        TRACE("factor_rewriter", tout << mk_pp(arg1, m()) << " = " << mk_pp(arg2, m()) << "\n";);
        return BR_FAILED;
    }

    expr_ref_vector disjuncts(m());
    for (powers_t::iterator cur = m_powers.begin(), last = m_powers.end(); cur != last; ++cur) {
        expr* factor = cur->m_key;
        // zero numeral of the factor's own sort
        expr* zero = a().mk_numeral(rational(0), m().get_sort(factor));
        disjuncts.push_back(m().mk_eq(factor, zero));
    }
    result = m().mk_or(disjuncts.size(), disjuncts.c_ptr());
    return BR_DONE;
}

// Rewrite arg1 <= arg2 using the factorization of arg1 - arg2.
//
// With no monomials left the inequality is rewritten to true. If no
// factorization is found the call fails with BR_FAILED. Otherwise the
// product of factors is <= 0 exactly when one factor is zero or the
// product is negative, e.g.
//     a^2 * b^3 * c <= 0  ->
//     a = 0 \/ (b = 0 \/ b > 0 & c <= 0 \/ b < 0 & c >= 0)
br_status factor_rewriter::mk_le(expr * arg1, expr * arg2, expr_ref & result) {
    mk_adds(arg1, arg2);
    mk_muls();

    if (m_muls.empty()) {
        result = m().mk_true();
        return BR_DONE;
    }

    if (!extract_factors()) {
        TRACE("factor_rewriter", tout << mk_pp(arg1, m()) << " <= " << mk_pp(arg2, m()) << "\n";);
        return BR_FAILED;
    }

    // disjuncts first receives the "factor = 0" equations; the strict
    // negativity condition is appended as the last disjunct.
    expr_ref is_neg(m());
    expr_ref_vector disjuncts(m());
    mk_is_negative(is_neg, disjuncts);
    disjuncts.push_back(is_neg);
    result = m().mk_or(disjuncts.size(), disjuncts.c_ptr());

    TRACE("factor_rewriter", 
          tout << mk_pp(arg1, m()) << " <= " << mk_pp(arg2, m()) << "\n";
          tout << mk_pp(result.get(), m()) << "\n";);
    return BR_DONE;
}

// Rewrite arg1 < arg2 using the factorization of arg1 - arg2.
//
// With no monomials left the inequality is rewritten to false. If no
// factorization is found the call fails with BR_FAILED. Otherwise the
// product of factors is < 0 exactly when no factor is zero and the
// product is negative, e.g.
//     a^2 * b^3 * c < 0  ->
//     a != 0 /\ (b > 0 & c < 0 \/ b < 0 & c > 0)
br_status factor_rewriter::mk_lt(expr * arg1, expr * arg2, expr_ref & result) {
    mk_adds(arg1, arg2);
    mk_muls();

    if (m_muls.empty()) {
        result = m().mk_false();
        return BR_DONE;
    }

    if (!extract_factors()) {
        TRACE("factor_rewriter", tout << mk_pp(arg1, m()) << " < " << mk_pp(arg2, m()) << "\n";);
        return BR_FAILED;
    }

    expr_ref is_neg(m());
    expr_ref_vector conjuncts(m());
    mk_is_negative(is_neg, conjuncts);

    // Turn every "factor = 0" equation into its negation.
    unsigned const num_eqs = conjuncts.size();
    for (unsigned k = 0; k < num_eqs; ++k) {
        conjuncts[k] = m().mk_not(conjuncts[k].get());
    }

    conjuncts.push_back(is_neg);
    result = m().mk_and(conjuncts.size(), conjuncts.c_ptr());

    TRACE("factor_rewriter", tout << mk_pp(result.get(), m()) << "\n";);
    return BR_DONE;
}

// Build the condition under which the product of the factors in m_powers
// is strictly negative.
//
// Outputs:
//   eqs    - one equation "0 = factor" is appended for every factor,
//            whether its exponent is odd or even.
//   result - formula over the factors with odd exponent that holds when
//            an odd number of them are strictly negative and the rest are
//            strictly positive; false if there is no odd-exponent factor.
//
// Requires m_powers to be non-empty (asserted); the zero numeral is created
// with the sort of the first factor and reused for all factors.
void factor_rewriter::mk_is_negative(expr_ref& result, expr_ref_vector& eqs) {
    powers_t::iterator it = m_powers.begin(), end = m_powers.end();
    SASSERT(m_powers.size() >= 1);
    SASSERT(it != end);
    expr_ref neg0(m()), neg(m()), pos0(m()), pos(m()), tmp(m());
    expr* e = it->m_key;
    expr_ref zero(a().mk_numeral(rational(0), m().get_sort(e)), m());
    // pos0 / neg0: "product of the odd-exponent factors seen so far is
    // positive / negative". The empty product is positive.
    pos0 = m().mk_true();
    neg0 = m().mk_false();
    for(; it != end; ++it) {
        e = it->m_key;
        eqs.push_back(m().mk_eq(zero, e));
        if (!even(it->m_value)) {
            pos = a().mk_lt(zero, e);
            neg = a().mk_lt(e, zero);
            if (m().is_false(neg0)) {
                // First odd-exponent factor: neg0 is still the literal false
                // it was initialized to, so take the factor's sign directly.
                neg0 = neg;
                pos0 = pos;
            }
            else {
                // Combine signs: equal signs give a positive product,
                // opposite signs a negative one. pos0 is updated through
                // tmp because the new neg0 still needs the old pos0.
                tmp = m().mk_or(m().mk_and(pos, pos0), m().mk_and(neg, neg0));
                neg0 = m().mk_or(m().mk_and(neg, pos0), m().mk_and(pos, neg0));
                pos0 = tmp;
            }
        }
    }
    result = neg0;
}

// Flatten arg1 - arg2 into a signed list of summands stored in m_adds.
//
// Each entry is a pair (term, sign) where sign == true means the term is
// added and sign == false means it is subtracted. Additions, subtractions
// and unary minus are unfolded in place; numerals equal to zero are dropped.
// m_muls (the per-summand product lists) is derived from m_adds by mk_muls.
void factor_rewriter::mk_adds(expr* arg1, expr* arg2) {
    m_adds.reset();
    m_adds.push_back(std::make_pair(arg1, true));
    m_adds.push_back(std::make_pair(arg2, false));
    rational val;
    unsigned idx = 0;
    // idx only advances once the entry at idx cannot be unfolded further.
    while (idx < m_adds.size()) {
        expr* term     = m_adds[idx].first;
        bool  positive = m_adds[idx].second;

        TRACE("factor_rewriter", tout << idx << " " << mk_pp(term, m_manager) << "\n";);

        if (!is_app(term)) {
            ++idx;
            continue;
        }
        app* t = to_app(term);
        unsigned const n = t->get_num_args();
        bool const is_sum  = a().is_add(t);
        bool const is_diff = !is_sum && a().is_sub(t);
        if ((is_sum || is_diff) && n > 0) {
            // The first argument keeps the current sign and replaces the
            // entry; the remaining arguments are appended, with flipped sign
            // in the case of a subtraction.
            bool const tail_sign = is_sum ? positive : !positive;
            m_adds[idx].first = t->get_arg(0);
            for (unsigned j = 1; j < n; ++j) {
                m_adds.push_back(std::make_pair(t->get_arg(j), tail_sign));
            }
        }
        else if (a().is_uminus(t)) {
            m_adds[idx].first  = t->get_arg(0);
            m_adds[idx].second = !positive;
        }
        else if (a().is_numeral(t, val) && val.is_zero()) {
            // Drop the zero by overwriting it with the last entry and
            // shrinking the vector; the moved entry is examined next.
            unsigned const last = m_adds.size() - 1;
            m_adds[idx] = m_adds[last];
            m_adds.resize(last);
        }
        else {
            ++idx;
        }
    }
    TRACE("factor_rewriter",
        for (unsigned k = 0; k < m_adds.size(); ++k) {
            if (!m_adds[k].second) tout << "-"; else tout << "+";
            tout << mk_pp(m_adds[k].first, m()) << " ";
        }
        tout << "\n";
        );
}

void factor_rewriter::mk_muls() {
    m_muls.reset();
    for (unsigned i = 0; i < m_adds.size(); ++i) {
        m_muls.push_back(ptr_vector<expr>());
        m_muls.back().push_back(m_adds[i].first);
        mk_expand_muls(m_muls.back());
        if (m_muls.back().empty()) {
            m_muls.pop_back();
            m_adds.erase(m_adds.begin() + i);
            --i;
        }
    }
    TRACE("factor_rewriter", 
        for (unsigned i = 0; i < m_muls.size(); ++i) {
            for (unsigned j = 0; j < m_muls[i].size(); ++j) {
                tout << mk_pp(m_muls[i][j], m()) << " ";
            }
            tout << "\n";
        }
        tout << "\n";
        );
}

// Flatten nested multiplications inside `muls` in place.
// A multiplication with at least one argument is replaced by its first
// argument and its remaining arguments are appended to the vector; the
// replaced slot is inspected again, so nesting of any depth is unfolded.
void factor_rewriter::mk_expand_muls(ptr_vector<expr>& muls) {
    unsigned idx = 0;
    while (idx < muls.size()) {
        expr* cur = muls[idx];
        if (is_app(cur)) {
            app* prod = to_app(cur);
            unsigned const n = prod->get_num_args();
            if (a().is_mul(prod) && n > 0) {
                muls[idx] = prod->get_arg(0);
                for (unsigned j = 1; j < n; ++j) {
                    muls.push_back(prod->get_arg(j));
                }
                // Re-examine the slot: the first argument may itself be a product.
                continue;
            }
        }
        ++idx;
    }
}

// Try to express the sum of products held in m_adds / m_muls as a product
// of factors, stored in m_factors and tallied into m_powers.
//
// Returns true when factors were produced and false otherwise:
//  - with a single monomial, success requires more than one multiplicand;
//  - with several monomials, success requires at least one multiplicand
//    (compared by pointer identity) that occurs in every monomial.
// On the multi-monomial path the shared multiplicands are erased from
// m_muls, and the remaining sum is appended as a final factor.
bool factor_rewriter::extract_factors() {
    m_factors.reset();
    unsigned_vector pos;
    expr* e;
    SASSERT(!m_muls.empty());
    if (m_muls.size() == 1) {
        // Single monomial: its multiplicands are the factors.
        if (m_muls[0].size() > 1) {
            m_factors.append(m_muls[0].size(), m_muls[0].c_ptr());
            if (!m_adds[0].second) {
                // The monomial is subtracted: fold the sign into the first
                // numeral factor, or add an explicit -1 if there is none.
                bool found_numeral = false;
                sort* s = m().get_sort(m_muls[0][0]);
                rational v;
                for (unsigned i = 0; !found_numeral && i < m_factors.size(); ++i) {
                    if (a().is_numeral(m_factors[i].get(), v)) {
                        m_factors[i] = a().mk_numeral(-v, s);
                        found_numeral = true;
                    }
                }
                if (!found_numeral) {
                    m_factors.push_back(a().mk_numeral(rational(-1),s));
                }
            }
            collect_powers();
            return true;
        }
        return false;
    }
    // Several monomials: look for multiplicands of the first monomial that
    // also occur in every other monomial.
    for (unsigned i = 0; i < m_muls[0].size(); ++i) {
        pos.reset();
        pos.push_back(i);
        e = m_muls[0][i];
        bool ok = true;
        for (unsigned j = 1; ok && j < m_muls.size(); ++j) {
            ok = false;
            unsigned k = 0;
            for (k = 0; !ok && k < m_muls[j].size(); ++k) {
                ok = m_muls[j][k] == e;
            }
            // k was incremented once after the match, so k-1 is the matching
            // index. If nothing matched, ok is false and pos is not used.
            pos.push_back(k-1);
        }
        if (ok) {
            // e occurs in every monomial: record it and remove one
            // occurrence from each monomial.
            SASSERT(pos.size() == m_muls.size());
            m_factors.push_back(e);
            for (unsigned j = 0; j < pos.size(); ++j) {
                m_muls[j].erase(m_muls[j].begin() + pos[j]);
            }
            // Revisit index i, which now holds the next multiplicand. For
            // i == 0 the unsigned wrap-around is undone by the loop's ++i.
            --i;
        }
    }
    if (m_factors.empty()) {
        return false;
    }
    SASSERT(m_muls.size() == m_adds.size());
    // Rebuild what is left of every monomial and sum the pieces up; the sum
    // becomes the last factor.
    expr_ref_vector trail(m());
    sort* s = m().get_sort(m_factors[0].get());
    for (unsigned i = 0; i < m_adds.size(); ++i) {
        switch(m_muls[i].size()) {
        case 0: 
            // All multiplicands were factored out: the remainder is 1.
            e = a().mk_numeral(rational(1), s);
            break;
        case 1: 
            e = m_muls[i][0];
            break;
        default: 
            e = a().mk_mul(m_muls[i].size(), m_muls[i].c_ptr()); 
            break;
        }
        if (!m_adds[i].second) {
            // Subtracted monomial: negate its remainder.
            e = a().mk_uminus(e);
        }
        trail.push_back(e);
    }       
    switch(trail.size()) {
    case 0: 
        break;
    case 1: 
        m_factors.push_back(trail[0].get());
        break;
    default: 
        m_factors.push_back(a().mk_add(trail.size(), trail.c_ptr()));
        break;
    }
    TRACE("factor_rewriter",
        for (unsigned i = 0; i < m_factors.size(); ++i) {
            tout << mk_pp(m_factors[i].get(), m()) << " ";
        }
        tout << "\n";
        );
    collect_powers();
    return true;
}

void factor_rewriter::collect_powers() {
    m_powers.reset();
    for (unsigned i = 0; i < m_factors.size(); ++i) {
        obj_map<expr,unsigned>::obj_map_entry* entry = m_powers.insert_if_not_there2(m_factors[i].get(), 0);
        if (entry) {
            ++(entry->get_data().m_value);
        }
    }
}

// Explicit instantiation of rewriter_tpl for the factor_rewriter_cfg
// configuration. NOTE(review): presumably the template's member definitions
// come from rewriter_def.h included at the top of this file - confirm.
template class rewriter_tpl<factor_rewriter_cfg>;