From ae0a91256a6bbbdbb712c9f3473ac1945a303b22 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtan@dynare.org>
Date: Mon, 15 Jul 2019 12:18:26 -0400
Subject: [PATCH] add cubic root to dynare language

---
 src/CodeInterpreter.hh |  1 +
 src/DataTree.cc        | 12 ++++++++++++
 src/DataTree.hh        |  4 +++-
 src/DynareBison.yy     |  4 +++-
 src/DynareFlex.ll      |  1 +
 src/ExprNode.cc        | 28 ++++++++++++++++++++++++++++
 src/ParsingDriver.cc   |  6 ++++++
 src/ParsingDriver.hh   |  2 ++
 8 files changed, 56 insertions(+), 2 deletions(-)

diff --git a/src/CodeInterpreter.hh b/src/CodeInterpreter.hh
index 249e4680..9841a497 100644
--- a/src/CodeInterpreter.hh
+++ b/src/CodeInterpreter.hh
@@ -191,6 +191,7 @@ enum class UnaryOpcode
     asinh,
     atanh,
     sqrt,
+    cbrt,
     abs,
     sign,
     steadyState,
diff --git a/src/DataTree.cc b/src/DataTree.cc
index 5f7f837a..4349c525 100644
--- a/src/DataTree.cc
+++ b/src/DataTree.cc
@@ -35,6 +35,7 @@ DataTree::initConstants()
   Zero = AddNonNegativeConstant("0");
   One = AddNonNegativeConstant("1");
   Two = AddNonNegativeConstant("2");
+  Three = AddNonNegativeConstant("3");
 
   MinusOne = AddUMinus(One);
 
@@ -517,6 +518,17 @@ DataTree::AddSqrt(expr_t iArg1)
     return Zero;
 }
 
+expr_t
+DataTree::AddCbrt(expr_t iArg1)
+{
+  if (iArg1 == Zero)
+    return Zero;
+  else if (iArg1 == One)
+    return One;
+  else
+    return AddUnaryOp(UnaryOpcode::cbrt, iArg1);
+}
+
 expr_t
 DataTree::AddAbs(expr_t iArg1)
 {
diff --git a/src/DataTree.hh b/src/DataTree.hh
index 6e7da7a2..9b62ad6a 100644
--- a/src/DataTree.hh
+++ b/src/DataTree.hh
@@ -132,7 +132,7 @@ public:
   DataTree & operator=(DataTree &&) = delete;
 
   //! Some predefined constants
-  expr_t Zero, One, Two, MinusOne, NaN, Infinity, MinusInfinity, Pi;
+  expr_t Zero, One, Two, Three, MinusOne, NaN, Infinity, MinusInfinity, Pi;
 
   //! Raised when a local parameter is declared twice
   class LocalVariableException
@@ -221,6 +221,8 @@ public:
   expr_t AddAtanh(expr_t iArg1);
   //! Adds "sqrt(arg)" to model tree
   expr_t AddSqrt(expr_t iArg1);
+  //! Adds "cbrt(arg)" to model tree
+  expr_t AddCbrt(expr_t iArg1);
   //! Adds "abs(arg)" to model tree
   expr_t AddAbs(expr_t iArg1);
   //! Adds "sign(arg)" to model tree
diff --git a/src/DynareBison.yy b/src/DynareBison.yy
index 1a668e36..7e501a74 100644
--- a/src/DynareBison.yy
+++ b/src/DynareBison.yy
@@ -127,7 +127,7 @@ class ParsingDriver;
 %precedence UMINUS UPLUS
 %nonassoc POWER
 %token EXP LOG LN LOG10 SIN COS TAN ASIN ACOS ATAN ERF DIFF ADL AUXILIARY_MODEL_NAME
-%token SQRT NORMCDF NORMPDF STEADY_STATE EXPECTATION VAR_ESTIMATION
+%token SQRT CBRT NORMCDF NORMPDF STEADY_STATE EXPECTATION VAR_ESTIMATION
 /* GSA analysis */
 %token DYNARE_SENSITIVITY MORRIS STAB REDFORM PPRIOR PRIOR_RANGE PPOST ILPTAU MORRIS_NLIV
 %token MORRIS_NTRA NSAM LOAD_REDFORM LOAD_RMSE LOAD_STAB ALPHA2_STAB LOGTRANS_REDFORM THRESHOLD_REDFORM
@@ -800,6 +800,8 @@ expression : '(' expression ')'
              { $$ = driver.add_atan($3); }
            | SQRT '(' expression ')'
              { $$ = driver.add_sqrt($3); }
+           | CBRT '(' expression ')'
+             { $$ = driver.add_cbrt($3); }
            | ABS '(' expression ')'
              { $$ = driver.add_abs($3); }
            | SIGN '(' expression ')'
diff --git a/src/DynareFlex.ll b/src/DynareFlex.ll
index fb356983..2bf32a95 100644
--- a/src/DynareFlex.ll
+++ b/src/DynareFlex.ll
@@ -830,6 +830,7 @@ DATE -?[0-9]+([ya]|m([1-9]|1[0-2])|q[1-4]|w([1-9]{1}|[1-4][0-9]|5[0-2]))
 <DYNARE_STATEMENT,DYNARE_BLOCK>acos {return token::ACOS;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>atan {return token::ATAN;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>sqrt {return token::SQRT;}
+<DYNARE_STATEMENT,DYNARE_BLOCK>cbrt {return token::CBRT;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>max {return token::MAX;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>min {return token::MIN;}
 <DYNARE_STATEMENT,DYNARE_BLOCK>abs {return token::ABS;}
diff --git a/src/ExprNode.cc b/src/ExprNode.cc
index a9ff7da6..2082de49 100644
--- a/src/ExprNode.cc
+++ b/src/ExprNode.cc
@@ -2171,6 +2171,10 @@ UnaryOpNode::composeDerivatives(expr_t darg, int deriv_id)
     case UnaryOpcode::sqrt:
       t11 = datatree.AddPlus(this, this);
       return datatree.AddDivide(darg, t11);
+    case UnaryOpcode::cbrt:
+      t11 = datatree.AddPower(arg, datatree.AddDivide(datatree.Two, datatree.Three));
+      t12 = datatree.AddTimes(datatree.Three, t11);
+      return datatree.AddDivide(darg, t12);
     case UnaryOpcode::abs:
       t11 = datatree.AddSign(arg);
       return datatree.AddTimes(t11, darg);
@@ -2313,6 +2317,7 @@ UnaryOpNode::cost(int cost, bool is_matlab) const
       case UnaryOpcode::atanh:
         return cost + 350;
       case UnaryOpcode::sqrt:
+      case UnaryOpcode::cbrt:
       case UnaryOpcode::abs:
         return cost + 570;
       case UnaryOpcode::steadyState:
@@ -2361,6 +2366,7 @@ UnaryOpNode::cost(int cost, bool is_matlab) const
       case UnaryOpcode::atanh:
         return cost + 150;
       case UnaryOpcode::sqrt:
+      case UnaryOpcode::cbrt:
       case UnaryOpcode::abs:
         return cost + 90;
       case UnaryOpcode::steadyState:
@@ -2500,6 +2506,9 @@ UnaryOpNode::writeJsonAST(ostream &output) const
     case UnaryOpcode::sqrt:
       output << "sqrt";
       break;
+    case UnaryOpcode::cbrt:
+      output << "cbrt";
+      break;
     case UnaryOpcode::abs:
       output << "abs";
       break;
@@ -2618,6 +2627,9 @@ UnaryOpNode::writeJsonOutput(ostream &output,
     case UnaryOpcode::sqrt:
       output << "sqrt";
       break;
+    case UnaryOpcode::cbrt:
+      output << "cbrt";
+      break;
     case UnaryOpcode::abs:
       output << "abs";
       break;
@@ -2774,6 +2786,9 @@ UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
     case UnaryOpcode::sqrt:
       output << "sqrt";
       break;
+    case UnaryOpcode::cbrt:
+      output << "cbrt";
+      break;
     case UnaryOpcode::abs:
       output << "abs";
       break;
@@ -2957,6 +2972,8 @@ UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) noexcept(false)
       return (atanh(v));
     case UnaryOpcode::sqrt:
       return (sqrt(v));
+    case UnaryOpcode::cbrt:
+      return (cbrt(v));
     case UnaryOpcode::abs:
       return (abs(v));
     case UnaryOpcode::sign:
@@ -3107,6 +3124,9 @@ UnaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>>
         case UnaryOpcode::sqrt:
           List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.Two);
           return { 1, nullptr };
+        case UnaryOpcode::cbrt:
+          List_of_Op_RHS.emplace_back(static_cast<int>(BinaryOpcode::power), nullptr, datatree.Three);
+          return { 1, nullptr };
         case UnaryOpcode::abs:
           return { 2, nullptr };
         case UnaryOpcode::sign:
@@ -3159,6 +3179,8 @@ UnaryOpNode::normalizeEquation(int var_endo, vector<tuple<int, expr_t, expr_t>>
           return { 0, datatree.AddAtanh(New_expr_t) };
         case UnaryOpcode::sqrt:
           return { 0, datatree.AddSqrt(New_expr_t) };
+        case UnaryOpcode::cbrt:
+          return { 0, datatree.AddCbrt(New_expr_t) };
         case UnaryOpcode::abs:
           return { 0, datatree.AddAbs(New_expr_t) };
         case UnaryOpcode::sign:
@@ -3222,6 +3244,8 @@ UnaryOpNode::buildSimilarUnaryOpNode(expr_t alt_arg, DataTree &alt_datatree) con
       return alt_datatree.AddAtanh(alt_arg);
     case UnaryOpcode::sqrt:
       return alt_datatree.AddSqrt(alt_arg);
+    case UnaryOpcode::cbrt:
+      return alt_datatree.AddCbrt(alt_arg);
     case UnaryOpcode::abs:
       return alt_datatree.AddAbs(alt_arg);
     case UnaryOpcode::sign:
@@ -3416,6 +3440,7 @@ UnaryOpNode::createAuxVarForUnaryOpNode() const
     case UnaryOpcode::asinh:
     case UnaryOpcode::atanh:
     case UnaryOpcode::sqrt:
+    case UnaryOpcode::cbrt:
     case UnaryOpcode::abs:
     case UnaryOpcode::sign:
     case UnaryOpcode::erf:
@@ -3627,6 +3652,9 @@ UnaryOpNode::substituteUnaryOpNodes(DataTree &static_datatree, diff_table_t &nod
     case UnaryOpcode::sqrt:
       unary_op = "sqrt";
       break;
+    case UnaryOpcode::cbrt:
+      unary_op = "cbrt";
+      break;
     case UnaryOpcode::abs:
       unary_op = "abs";
       break;
diff --git a/src/ParsingDriver.cc b/src/ParsingDriver.cc
index dab1623c..9aa88789 100644
--- a/src/ParsingDriver.cc
+++ b/src/ParsingDriver.cc
@@ -2827,6 +2827,12 @@ ParsingDriver::add_sqrt(expr_t arg1)
   return data_tree->AddSqrt(arg1);
 }
 
+expr_t
+ParsingDriver::add_cbrt(expr_t arg1)
+{
+  return data_tree->AddCbrt(arg1);
+}
+
 expr_t
 ParsingDriver::add_abs(expr_t arg1)
 {
diff --git a/src/ParsingDriver.hh b/src/ParsingDriver.hh
index 766da8dd..ed5eef0a 100644
--- a/src/ParsingDriver.hh
+++ b/src/ParsingDriver.hh
@@ -778,6 +778,8 @@ public:
   expr_t add_atanh(expr_t arg1);
   //! Writes token "sqrt(arg1)" to model tree
   expr_t add_sqrt(expr_t arg1);
+  //! Writes token "cbrt(arg1)" to model tree
+  expr_t add_cbrt(expr_t arg1);
   //! Writes token "abs(arg1)" to model tree
   expr_t add_abs(expr_t arg1);
   //! Writes token "sign(arg1)" to model tree
-- 
GitLab