diff --git a/include/mathfuncRewrite.hpp b/include/mathfuncRewrite.hpp index ac66288a47330f63255dd52e0780a62a6ec026c5..b72e0caa85f859724867a5018ea68d067a45abf2 100644 --- a/include/mathfuncRewrite.hpp +++ b/include/mathfuncRewrite.hpp @@ -19,4 +19,10 @@ std::unique_ptr sqrtTohypot(const std::unique_ptr &expr); // log(exp(x))⇒x OR exp(log(x))⇒x std::unique_ptr lex_x_Or_elx_x(const std::unique_ptr &expr); +// sqrt(x)*sqrt(y)⇒sqrt(x*y) +std::unique_ptr sqrtMult(const std::unique_ptr &expr); + +// sqrt(x)/sqrt(y)⇒sqrt(x/y) +std::unique_ptr sqrtDiv(const std::unique_ptr &expr); + #endif diff --git a/src/mathfuncRewrite.cpp b/src/mathfuncRewrite.cpp index 2cee79eec8765be2b1d2b7e5e9940a0e1ec3bd6b..b1d9d402f07cdbe21935b584abd625042b8bd7a7 100644 --- a/src/mathfuncRewrite.cpp +++ b/src/mathfuncRewrite.cpp @@ -1,6 +1,9 @@ #include "basic.hpp" #include "expandAST.hpp" #include "mathfuncRewrite.hpp" +#include +#include +#include "exprAuto.hpp" //===----------------------------------------------------------------------===// // Equivalent transformation of mathematical function @@ -206,4 +209,136 @@ std::unique_ptr lex_x_Or_elx_x(const std::unique_ptr &expr) } } return expr->Clone(); +} + +// sqrt(x)*sqrt(y) ==> sqrt(x*y) +std::unique_ptr sqrtMult(const std::unique_ptr &expr) +{ + if(expr == nullptr) + { + fprintf(stderr, "empty\n"); + return nullptr; + } + if(expr->type() == "Binary") + { + BinaryExprAST *binOp = dynamic_cast(expr.get()); + char op = binOp->getOp(); + std::string opStr(1, op); +#ifdef DEBUG + fprintf(stderr, "op: %s\n", opStr.c_str()); +#endif + + std::unique_ptr &lhs = binOp->getLHS(); + std::unique_ptr &rhs = binOp->getRHS(); + + const std::string exprTypeLHS = lhs->type(); + const std::string exprTypeRHS = rhs->type(); + + if(op == '*') + { + if((exprTypeLHS == "Call") && (exprTypeRHS == "Call")) + { + CallExprAST *callExprL = dynamic_cast(lhs.get()); + CallExprAST *callExprR = dynamic_cast(rhs.get()); + + std::string calleeL = (callExprL->getCallee()); + std::string calleeR = (callExprR->getCallee()); + + std::vector> &argsL = callExprL->getArgs(); //左表达式中函数的参数 + std::vector> &argsR = callExprR->getArgs(); //右表达式中函数的参数 + + std::vector> argsNew; +#ifdef DEBUG + fprintf(stderr, "call: %s\n", calleeL.c_str()); +#endif + if((calleeL == "sqrt") && (calleeR == "sqrt")) + { + if((argsL.size() == 1) && (argsR.size() == 1)) + {//取出左右表达式中的参数 + auto argL = std::move(argsL.at(0)); + auto argR = std::move(argsR.at(0)); + //存放转换后的表达式x*y + auto argsFinal =mulExpr(argL, argR); + argsNew.push_back(std::move(argsFinal)); + } + std::string calleeNew = "sqrt"; + std::unique_ptr exprFinal = std::make_unique(calleeNew, std::move(argsNew)); + return exprFinal; + } + else + return expr->Clone(); + } + else + return expr->Clone(); + } + else + return expr->Clone(); + } + return expr->Clone(); +} + +// sqrt(x)/sqrt(y) ==> sqrt(x/y) +std::unique_ptr sqrtDiv(const std::unique_ptr &expr) +{ + if(expr == nullptr) + { + fprintf(stderr, "empty\n"); + return nullptr; + } + if(expr->type() == "Binary") + { + BinaryExprAST *binOp = dynamic_cast(expr.get()); + char op = binOp->getOp(); + std::string opStr(1, op); +#ifdef DEBUG + fprintf(stderr, "op: %s\n", opStr.c_str()); +#endif + + std::unique_ptr &lhs = binOp->getLHS(); + std::unique_ptr &rhs = binOp->getRHS(); + + const std::string exprTypeLHS = lhs->type(); + const std::string exprTypeRHS = rhs->type(); + + if(op == '/') + { + if((exprTypeLHS == "Call") && (exprTypeRHS == "Call")) + { + CallExprAST *callExprL = dynamic_cast(lhs.get()); + CallExprAST *callExprR = dynamic_cast(rhs.get()); + + std::string calleeL = (callExprL->getCallee()); + std::string calleeR = (callExprR->getCallee()); + + std::vector> &argsL = callExprL->getArgs(); //左表达式中函数的参数 + std::vector> &argsR = callExprR->getArgs(); //右表达式中函数的参数 + + std::vector> argsNew; +#ifdef DEBUG + fprintf(stderr, "call: %s\n", calleeL.c_str()); +#endif + if((calleeL == "sqrt") && (calleeR == "sqrt")) + { + if((argsL.size() == 1) && (argsR.size() == 1)) + {//取出左右表达式中的参数 + auto argL = std::move(argsL.at(0)); + auto argR = std::move(argsR.at(0)); + //存放转换后的表达式x/y + auto argsFinal =divExpr(argL, argR); + argsNew.push_back(std::move(argsFinal)); + } + std::string calleeNew = "sqrt"; + std::unique_ptr exprFinal = std::make_unique(calleeNew, std::move(argsNew)); + return exprFinal; + } + else + return expr->Clone(); + } + else + return expr->Clone(); + } + else + return expr->Clone(); + } + return expr->Clone(); } \ No newline at end of file