From 9b3d611a0b92e215bfd2ad39c4b1938c1c6285e1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien.villemot@ens.fr>
Date: Wed, 26 Jan 2011 13:55:01 -0500
Subject: [PATCH] Preprocessor: in steady_state_model block, allow MATLAB
 functions which return several arguments (closes #37)

---
 DynareBison.yy      |  4 ++-
 ExprNode.cc         |  2 +-
 ParsingDriver.cc    | 29 ++++++++++++++++
 ParsingDriver.hh    |  2 ++
 SteadyStateModel.cc | 80 +++++++++++++++++++++++++++++++++------------
 SteadyStateModel.hh | 10 +++---
 SymbolList.hh       |  4 ++-
 7 files changed, 104 insertions(+), 27 deletions(-)

diff --git a/DynareBison.yy b/DynareBison.yy
index 6d056c7c..1a25176e 100644
--- a/DynareBison.yy
+++ b/DynareBison.yy
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2003-2010 Dynare Team
+ * Copyright (C) 2003-2011 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -1651,6 +1651,8 @@ steady_state_equation_list : steady_state_equation_list steady_state_equation
 
 steady_state_equation : symbol EQUAL expression ';'
                         { driver.add_steady_state_model_equal($1, $3); }
+                      | '[' symbol_list ']' EQUAL expression ';'
+                        { driver.add_steady_state_model_equal_multiple($5); }
                       ;
 
 o_dr_algo : DR_ALGO EQUAL INT_NUMBER {
diff --git a/ExprNode.cc b/ExprNode.cc
index ec907ab9..21d75bcd 100644
--- a/ExprNode.cc
+++ b/ExprNode.cc
@@ -4075,7 +4075,7 @@ ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_typ
                                   const temporary_terms_t &temporary_terms,
                                   deriv_node_temp_terms_t &tef_terms) const
 {
-  if (output_type == oMatlabOutsideModel)
+  if (output_type == oMatlabOutsideModel || output_type == oSteadyStateFile)
     {
       output << datatree.symbol_table.getName(symb_id) << "(";
       writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms);
diff --git a/ParsingDriver.cc b/ParsingDriver.cc
index 99a7e33d..51e9db10 100644
--- a/ParsingDriver.cc
+++ b/ParsingDriver.cc
@@ -1899,3 +1899,32 @@ ParsingDriver::add_steady_state_model_equal(string *varname, expr_t expr)
 
   delete varname;
 }
+
+void
+ParsingDriver::add_steady_state_model_equal_multiple(expr_t expr)
+{
+  const vector<string> &symbs = symbol_list.get_symbols();
+  vector<int> ids;
+
+  for (size_t i = 0; i < symbs.size(); i++)
+    {
+      int id;
+      try
+        {
+          id = mod_file->symbol_table.getID(symbs[i]);
+        }
+      catch (SymbolTable::UnknownSymbolNameException &e)
+        {
+          // Unknown symbol, declare it as a ModFileLocalVariable
+          id = mod_file->symbol_table.addSymbol(symbs[i], eModFileLocalVariable);
+        }
+      SymbolType type = mod_file->symbol_table.getType(id);
+      if (type != eEndogenous && type != eModFileLocalVariable && type != eParameter)
+        error(symbs[i] + " has incorrect type");
+      ids.push_back(id);
+    }
+
+  mod_file->steady_state_model.addMultipleDefinitions(ids, expr);
+
+  symbol_list.clear();
+}
diff --git a/ParsingDriver.hh b/ParsingDriver.hh
index f9b881a6..50ea88f5 100644
--- a/ParsingDriver.hh
+++ b/ParsingDriver.hh
@@ -496,6 +496,8 @@ public:
   void begin_steady_state_model();
   //! Add an assignment equation in steady_state_model block
   void add_steady_state_model_equal(string *varname, expr_t expr);
+  //! Add a multiple assignment equation in steady_state_model block
+  void add_steady_state_model_equal_multiple(expr_t expr);
   //! Switches datatree
   void begin_trend();
   //! Declares a trend variable with its growth factor
diff --git a/SteadyStateModel.cc b/SteadyStateModel.cc
index 4a916fb9..6caa166e 100644
--- a/SteadyStateModel.cc
+++ b/SteadyStateModel.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010 Dynare Team
+ * Copyright (C) 2010-2011 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -30,44 +30,70 @@ SteadyStateModel::SteadyStateModel(SymbolTable &symbol_table_arg, NumericalConst
 void
 SteadyStateModel::addDefinition(int symb_id, expr_t expr)
 {
+  AddVariable(symb_id); // Create the variable node to be used in write method
+
   assert(symbol_table.getType(symb_id) == eEndogenous
          || symbol_table.getType(symb_id) == eModFileLocalVariable
          || symbol_table.getType(symb_id) == eParameter);
 
   // Add the variable
-  recursive_order.push_back(symb_id);
-  def_table[symb_id] = AddEqual(AddVariable(symb_id), expr);
+  vector<int> v;
+  v.push_back(symb_id);
+  recursive_order.push_back(v);
+  def_table[v] = expr;
+}
+
+void
+SteadyStateModel::addMultipleDefinitions(const vector<int> &symb_ids, expr_t expr)
+{
+  for (size_t i = 0; i < symb_ids.size(); i++)
+    {
+      AddVariable(symb_ids[i]); // Create the variable nodes to be used in write method
+      assert(symbol_table.getType(symb_ids[i]) == eEndogenous
+             || symbol_table.getType(symb_ids[i]) == eModFileLocalVariable
+             || symbol_table.getType(symb_ids[i]) == eParameter);
+    }
+  recursive_order.push_back(symb_ids);
+  def_table[symb_ids] = expr;
 }
 
 void
 SteadyStateModel::checkPass(bool ramsey_policy) const
 {
-  for (vector<int>::const_iterator it = recursive_order.begin();
-       it != recursive_order.end(); ++it)
+  vector<int> so_far_defined;
+
+  for (size_t i = 0; i < recursive_order.size(); i++)
     {
-      // Check that symbol is not already defined
-      if (find(recursive_order.begin(), it, *it) != it)
-        {
-          cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(*it) << "' is declared twice" << endl;
-          exit(EXIT_FAILURE);
-        }
+      const vector<int> &symb_ids = recursive_order[i];
+
+      // Check that symbols are not already defined
+      for (size_t j = 0; j < symb_ids.size(); j++)
+        if (find(so_far_defined.begin(), so_far_defined.end(), symb_ids[j])
+            != so_far_defined.end())
+          {
+            cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(symb_ids[j]) << "' is declared twice" << endl;
+            exit(EXIT_FAILURE);
+          }
 
       // Check that expression has no undefined symbol
       if (!ramsey_policy)
         {
           set<pair<int, int> > used_symbols;
-          expr_t expr = def_table.find(*it)->second;
+          expr_t expr = def_table.find(symb_ids)->second;
           expr->collectVariables(eEndogenous, used_symbols);
           expr->collectVariables(eModFileLocalVariable, used_symbols);
-          for(set<pair<int, int> >::const_iterator it2 = used_symbols.begin();
-              it2 != used_symbols.end(); ++it2)
-            if (find(recursive_order.begin(), it, it2->first) == it
-                && *it != it2->first)
+          for(set<pair<int, int> >::const_iterator it = used_symbols.begin();
+              it != used_symbols.end(); ++it)
+            if (find(so_far_defined.begin(), so_far_defined.end(), it->first)
+                == so_far_defined.end())
               {
-                cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(it2->first) << "' is undefined in the declaration of variable '" << symbol_table.getName(*it) << "'" << endl;
+                cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(it->first)
+                     << "' is undefined in the declaration of variable '" << symbol_table.getName(symb_ids[0]) << "'" << endl;
                 exit(EXIT_FAILURE);
               }
         }
+
+      copy(symb_ids.begin(), symb_ids.end(), back_inserter(so_far_defined));
     }
 }
 
@@ -98,11 +124,25 @@ SteadyStateModel::writeSteadyStateFile(const string &basename, bool ramsey_polic
     output << "    ys_=zeros(" << symbol_table.orig_endo_nbr() << ",1);" << endl;
   output << "    global M_" << endl;
 
-  for(size_t i = 0; i < recursive_order.size(); i++)
+  for (size_t i = 0; i < recursive_order.size(); i++)
     {
+      const vector<int> &symb_ids = recursive_order[i];
       output << "    ";
-      map<int, expr_t>::const_iterator it = def_table.find(recursive_order[i]);
-      it->second->writeOutput(output, oSteadyStateFile);
+      if (symb_ids.size() > 1)
+        output << "[";
+      for (size_t j = 0; j < symb_ids.size(); j++)
+        {
+          variable_node_map_t::const_iterator it = variable_node_map.find(make_pair(symb_ids[j], 0));
+          assert(it != variable_node_map.end());
+          dynamic_cast<ExprNode *>(it->second)->writeOutput(output, oSteadyStateFile);
+          if (j < symb_ids.size()-1)
+            output << ",";
+        }
+      if (symb_ids.size() > 1)
+        output << "]";
+
+      output << "=";
+      def_table.find(symb_ids)->second->writeOutput(output, oSteadyStateFile);
       output << ";" << endl;
     }
   output << "    % Auxiliary equations" << endl;
diff --git a/SteadyStateModel.hh b/SteadyStateModel.hh
index d5677565..cbb206af 100644
--- a/SteadyStateModel.hh
+++ b/SteadyStateModel.hh
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010 Dynare Team
+ * Copyright (C) 2010-2011 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -26,9 +26,9 @@
 class SteadyStateModel : public DataTree
 {
 private:
-  //! Associates a symbol ID to an expression of the form "var = expr"
-  map<int, expr_t> def_table;
-  vector<int> recursive_order;
+  //! Associates a set of symbol IDs (the variable(s) assigned in a given statement) to an expression (their assigned value)
+  map<vector<int>, expr_t> def_table;
+  vector<vector<int> > recursive_order;
 
   //! Reference to static model (for writing auxiliary equations)
   const StaticModel &static_model;
@@ -37,6 +37,8 @@ public:
   SteadyStateModel(SymbolTable &symbol_table_arg, NumericalConstants &num_constants, ExternalFunctionsTable &external_functions_table_arg, const StaticModel &static_model_arg);
   //! Add an expression of the form "var = expr;"
   void addDefinition(int symb_id, expr_t expr);
+  //! Add an expression of the form "[ var1, var2, ... ] = expr;"
+  void addMultipleDefinitions(const vector<int> &symb_ids, expr_t expr);
   //! Checks that definitions are in a recursive order, and that no variable is declared twice
   /*!
     \param[in] ramsey_policy Is there a ramsey_policy statement in the MOD file? If yes, then disable the check on the recursivity of the declarations
diff --git a/SymbolList.hh b/SymbolList.hh
index e1ec9fac..dafa965b 100644
--- a/SymbolList.hh
+++ b/SymbolList.hh
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2003-2008 Dynare Team
+ * Copyright (C) 2003-2011 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -41,6 +41,8 @@ public:
   void writeOutput(const string &varname, ostream &output) const;
   //! Clears all content
   void clear();
+  //! Get a copy of the string vector
+  vector<string> get_symbols() const { return symbols; };
 };
 
 #endif
-- 
GitLab