diff --git a/src/DataTree.cc b/src/DataTree.cc index 3574b557415b9bee68d72db853326b0fc9651b14..1887d62721ce819a01e732151dfbf79d07502971 100644 --- a/src/DataTree.cc +++ b/src/DataTree.cc @@ -243,6 +243,16 @@ DataTree::AddTimes(expr_t iArg1, expr_t iArg2) return AddUMinus(iArg1); else if (iArg1 != Zero && iArg1 != One && iArg2 != Zero && iArg2 != One) { + // Simplify (x/y)*y in x + if (auto barg1 = dynamic_cast<BinaryOpNode *>(iArg1); + barg1 && barg1->op_code == BinaryOpcode::divide && barg1->arg2 == iArg2) + return barg1->arg1; + + // Simplify y*(x/y) in x + if (auto barg2 = dynamic_cast<BinaryOpNode *>(iArg2); + barg2 && barg2->op_code == BinaryOpcode::divide && barg2->arg2 == iArg1) + return barg2->arg1; + // To treat commutativity of "*" // Nodes iArg1 and iArg2 are sorted by index if (iArg1->idx > iArg2->idx) @@ -283,6 +293,16 @@ DataTree::AddDivide(expr_t iArg1, expr_t iArg2) noexcept(false) barg2 && barg2->op_code == BinaryOpcode::divide && barg2->arg1 == One) return AddTimes(iArg1, barg2->arg2); + // Simplify (x*y)/y and (y*x)/y in x + if (auto barg1 = dynamic_cast<BinaryOpNode *>(iArg1); + barg1 && barg1->op_code == BinaryOpcode::times) + { + if (barg1->arg2 == iArg2) + return barg1->arg1; + if (barg1->arg1 == iArg2) + return barg1->arg2; + } + return AddBinaryOp(iArg1, BinaryOpcode::divide, iArg2); }