From 7c65bceb0ef6298dfd1a1d42456fe705cfd54d93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?S=C3=A9bastien=20Villemot?= <sebastien.villemot@ens.fr>
Date: Wed, 26 Jan 2011 13:55:01 -0500
Subject: [PATCH] Preprocessor: in steady_state_model block, allow MATLAB
 functions which return several arguments (closes #37)

---
 doc/manual.xml                   | 10 ++--
 preprocessor/DynareBison.yy      |  4 +-
 preprocessor/ExprNode.cc         |  2 +-
 preprocessor/ParsingDriver.cc    | 29 ++++++++++++
 preprocessor/ParsingDriver.hh    |  2 +
 preprocessor/SteadyStateModel.cc | 80 ++++++++++++++++++++++++--------
 preprocessor/SteadyStateModel.hh | 10 ++--
 preprocessor/SymbolList.hh       |  4 +-
 tests/.gitignore                 |  1 +
 tests/Makefile.am                |  3 +-
 tests/fs2000_ssfile.mod          |  5 +-
 tests/fs2000_ssfile_aux.m        |  4 ++
 12 files changed, 121 insertions(+), 33 deletions(-)
 create mode 100644 tests/fs2000_ssfile_aux.m

diff --git a/doc/manual.xml b/doc/manual.xml
index 1b2abef85..1b651a920 100644
--- a/doc/manual.xml
+++ b/doc/manual.xml
@@ -1961,7 +1961,10 @@ steady(homotopy_mode = 1, homotopy_steps = 50);
       <arg choice="plain">;</arg>
       <sbr/>
       <arg choice="plain" rep="repeat">
-        <replaceable>VARIABLE_NAME</replaceable> = <replaceable>EXPRESSION</replaceable> ;
+        <group>
+          <arg choice="plain"><replaceable>VARIABLE_NAME</replaceable> = <replaceable>EXPRESSION</replaceable> ;</arg>
+          <arg choice="plain">[ <replaceable>VARIABLE_NAME</replaceable>, <arg choice="plain" rep="repeat"><replaceable>VARIABLE_NAME</replaceable></arg> ] = <replaceable>EXPRESSION</replaceable> ;</arg>
+        </group>
       </arg>
       <sbr/>
       <command>end</command><arg choice="plain">;</arg>
@@ -2011,8 +2014,9 @@ steady_state_model;
   d  = l - mst + 1;
   y  = k^alp*n^(1-alp)*gst^alp;
   R  = mst/bet;
-  W  = l/n;
-  e = 1;
+
+  // You can use MATLAB functions which return several arguments
+  [W, e] = my_function(l, n);
   
   gp_obs = m/dA;
   gy_obs = dA;
diff --git a/preprocessor/DynareBison.yy b/preprocessor/DynareBison.yy
index 6d056c7c5..1a25176e9 100644
--- a/preprocessor/DynareBison.yy
+++ b/preprocessor/DynareBison.yy
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2003-2010 Dynare Team
+ * Copyright (C) 2003-2011 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -1651,6 +1651,8 @@ steady_state_equation_list : steady_state_equation_list steady_state_equation
 
 steady_state_equation : symbol EQUAL expression ';'
                         { driver.add_steady_state_model_equal($1, $3); }
+                      | '[' symbol_list ']' EQUAL expression ';'
+                        { driver.add_steady_state_model_equal_multiple($5); }
                       ;
 
 o_dr_algo : DR_ALGO EQUAL INT_NUMBER {
diff --git a/preprocessor/ExprNode.cc b/preprocessor/ExprNode.cc
index ec907ab99..21d75bcdf 100644
--- a/preprocessor/ExprNode.cc
+++ b/preprocessor/ExprNode.cc
@@ -4075,7 +4075,7 @@ ExternalFunctionNode::writeOutput(ostream &output, ExprNodeOutputType output_typ
                                   const temporary_terms_t &temporary_terms,
                                   deriv_node_temp_terms_t &tef_terms) const
 {
-  if (output_type == oMatlabOutsideModel)
+  if (output_type == oMatlabOutsideModel || output_type == oSteadyStateFile)
     {
       output << datatree.symbol_table.getName(symb_id) << "(";
       writeExternalFunctionArguments(output, output_type, temporary_terms, tef_terms);
diff --git a/preprocessor/ParsingDriver.cc b/preprocessor/ParsingDriver.cc
index 99a7e33d1..51e9db10b 100644
--- a/preprocessor/ParsingDriver.cc
+++ b/preprocessor/ParsingDriver.cc
@@ -1899,3 +1899,32 @@ ParsingDriver::add_steady_state_model_equal(string *varname, expr_t expr)
 
   delete varname;
 }
+
+void
+ParsingDriver::add_steady_state_model_equal_multiple(expr_t expr)
+{
+  const vector<string> &symbs = symbol_list.get_symbols();
+  vector<int> ids;
+
+  for (size_t i = 0; i < symbs.size(); i++)
+    {
+      int id;
+      try
+        {
+          id = mod_file->symbol_table.getID(symbs[i]);
+        }
+      catch (SymbolTable::UnknownSymbolNameException &e)
+        {
+          // Unknown symbol, declare it as a ModFileLocalVariable
+          id = mod_file->symbol_table.addSymbol(symbs[i], eModFileLocalVariable);
+        }
+      SymbolType type = mod_file->symbol_table.getType(id);
+      if (type != eEndogenous && type != eModFileLocalVariable && type != eParameter)
+        error(symbs[i] + " has incorrect type");
+      ids.push_back(id);
+    }
+
+  mod_file->steady_state_model.addMultipleDefinitions(ids, expr);
+
+  symbol_list.clear();
+}
diff --git a/preprocessor/ParsingDriver.hh b/preprocessor/ParsingDriver.hh
index f9b881a60..50ea88f52 100644
--- a/preprocessor/ParsingDriver.hh
+++ b/preprocessor/ParsingDriver.hh
@@ -496,6 +496,8 @@ public:
   void begin_steady_state_model();
   //! Add an assignment equation in steady_state_model block
   void add_steady_state_model_equal(string *varname, expr_t expr);
+  //! Add a multiple assignment equation in steady_state_model block
+  void add_steady_state_model_equal_multiple(expr_t expr);
   //! Switches datatree
   void begin_trend();
   //! Declares a trend variable with its growth factor
diff --git a/preprocessor/SteadyStateModel.cc b/preprocessor/SteadyStateModel.cc
index 4a916fb93..6caa166ed 100644
--- a/preprocessor/SteadyStateModel.cc
+++ b/preprocessor/SteadyStateModel.cc
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010 Dynare Team
+ * Copyright (C) 2010-2011 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -30,44 +30,70 @@ SteadyStateModel::SteadyStateModel(SymbolTable &symbol_table_arg, NumericalConst
 void
 SteadyStateModel::addDefinition(int symb_id, expr_t expr)
 {
+  AddVariable(symb_id); // Create the variable node to be used in write method
+
   assert(symbol_table.getType(symb_id) == eEndogenous
          || symbol_table.getType(symb_id) == eModFileLocalVariable
          || symbol_table.getType(symb_id) == eParameter);
 
   // Add the variable
-  recursive_order.push_back(symb_id);
-  def_table[symb_id] = AddEqual(AddVariable(symb_id), expr);
+  vector<int> v;
+  v.push_back(symb_id);
+  recursive_order.push_back(v);
+  def_table[v] = expr;
+}
+
+void
+SteadyStateModel::addMultipleDefinitions(const vector<int> &symb_ids, expr_t expr)
+{
+  for (size_t i = 0; i < symb_ids.size(); i++)
+    {
+      AddVariable(symb_ids[i]); // Create the variable nodes to be used in write method
+      assert(symbol_table.getType(symb_ids[i]) == eEndogenous
+             || symbol_table.getType(symb_ids[i]) == eModFileLocalVariable
+             || symbol_table.getType(symb_ids[i]) == eParameter);
+    }
+  recursive_order.push_back(symb_ids);
+  def_table[symb_ids] = expr;
 }
 
 void
 SteadyStateModel::checkPass(bool ramsey_policy) const
 {
-  for (vector<int>::const_iterator it = recursive_order.begin();
-       it != recursive_order.end(); ++it)
+  vector<int> so_far_defined;
+
+  for (size_t i = 0; i < recursive_order.size(); i++)
     {
-      // Check that symbol is not already defined
-      if (find(recursive_order.begin(), it, *it) != it)
-        {
-          cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(*it) << "' is declared twice" << endl;
-          exit(EXIT_FAILURE);
-        }
+      const vector<int> &symb_ids = recursive_order[i];
+
+      // Check that symbols are not already defined
+      for (size_t j = 0; j < symb_ids.size(); j++)
+        if (find(so_far_defined.begin(), so_far_defined.end(), symb_ids[j])
+            != so_far_defined.end())
+          {
+            cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(symb_ids[j]) << "' is declared twice" << endl;
+            exit(EXIT_FAILURE);
+          }
 
       // Check that expression has no undefined symbol
       if (!ramsey_policy)
         {
           set<pair<int, int> > used_symbols;
-          expr_t expr = def_table.find(*it)->second;
+          expr_t expr = def_table.find(symb_ids)->second;
           expr->collectVariables(eEndogenous, used_symbols);
           expr->collectVariables(eModFileLocalVariable, used_symbols);
-          for(set<pair<int, int> >::const_iterator it2 = used_symbols.begin();
-              it2 != used_symbols.end(); ++it2)
-            if (find(recursive_order.begin(), it, it2->first) == it
-                && *it != it2->first)
+          for(set<pair<int, int> >::const_iterator it = used_symbols.begin();
+              it != used_symbols.end(); ++it)
+            if (find(so_far_defined.begin(), so_far_defined.end(), it->first)
+                == so_far_defined.end())
               {
-                cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(it2->first) << "' is undefined in the declaration of variable '" << symbol_table.getName(*it) << "'" << endl;
+                cerr << "ERROR: in the 'steady_state' block, variable '" << symbol_table.getName(it->first)
+                     << "' is undefined in the declaration of variable '" << symbol_table.getName(symb_ids[0]) << "'" << endl;
                 exit(EXIT_FAILURE);
               }
         }
+
+      copy(symb_ids.begin(), symb_ids.end(), back_inserter(so_far_defined));
     }
 }
 
@@ -98,11 +124,25 @@ SteadyStateModel::writeSteadyStateFile(const string &basename, bool ramsey_polic
     output << "    ys_=zeros(" << symbol_table.orig_endo_nbr() << ",1);" << endl;
   output << "    global M_" << endl;
 
-  for(size_t i = 0; i < recursive_order.size(); i++)
+  for (size_t i = 0; i < recursive_order.size(); i++)
     {
+      const vector<int> &symb_ids = recursive_order[i];
       output << "    ";
-      map<int, expr_t>::const_iterator it = def_table.find(recursive_order[i]);
-      it->second->writeOutput(output, oSteadyStateFile);
+      if (symb_ids.size() > 1)
+        output << "[";
+      for (size_t j = 0; j < symb_ids.size(); j++)
+        {
+          variable_node_map_t::const_iterator it = variable_node_map.find(make_pair(symb_ids[j], 0));
+          assert(it != variable_node_map.end());
+          dynamic_cast<ExprNode *>(it->second)->writeOutput(output, oSteadyStateFile);
+          if (j < symb_ids.size()-1)
+            output << ",";
+        }
+      if (symb_ids.size() > 1)
+        output << "]";
+
+      output << "=";
+      def_table.find(symb_ids)->second->writeOutput(output, oSteadyStateFile);
       output << ";" << endl;
     }
   output << "    % Auxiliary equations" << endl;
diff --git a/preprocessor/SteadyStateModel.hh b/preprocessor/SteadyStateModel.hh
index d5677565f..cbb206afa 100644
--- a/preprocessor/SteadyStateModel.hh
+++ b/preprocessor/SteadyStateModel.hh
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2010 Dynare Team
+ * Copyright (C) 2010-2011 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -26,9 +26,9 @@
 class SteadyStateModel : public DataTree
 {
 private:
-  //! Associates a symbol ID to an expression of the form "var = expr"
-  map<int, expr_t> def_table;
-  vector<int> recursive_order;
+  //! Associates a set of symbol IDs (the variable(s) assigned in a given statement) to an expression (their assigned value)
+  map<vector<int>, expr_t> def_table;
+  vector<vector<int> > recursive_order;
 
   //! Reference to static model (for writing auxiliary equations)
   const StaticModel &static_model;
@@ -37,6 +37,8 @@ public:
   SteadyStateModel(SymbolTable &symbol_table_arg, NumericalConstants &num_constants, ExternalFunctionsTable &external_functions_table_arg, const StaticModel &static_model_arg);
   //! Add an expression of the form "var = expr;"
   void addDefinition(int symb_id, expr_t expr);
+  //! Add an expression of the form "[ var1, var2, ... ] = expr;"
+  void addMultipleDefinitions(const vector<int> &symb_ids, expr_t expr);
   //! Checks that definitions are in a recursive order, and that no variable is declared twice
   /*!
     \param[in] ramsey_policy Is there a ramsey_policy statement in the MOD file? If yes, then disable the check on the recursivity of the declarations
diff --git a/preprocessor/SymbolList.hh b/preprocessor/SymbolList.hh
index e1ec9fac8..dafa965bf 100644
--- a/preprocessor/SymbolList.hh
+++ b/preprocessor/SymbolList.hh
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2003-2008 Dynare Team
+ * Copyright (C) 2003-2011 Dynare Team
  *
  * This file is part of Dynare.
  *
@@ -41,6 +41,8 @@ public:
   void writeOutput(const string &varname, ostream &output) const;
   //! Clears all content
   void clear();
+  //! Get a copy of the string vector
+  vector<string> get_symbols() const { return symbols; };
 };
 
 #endif
diff --git a/tests/.gitignore b/tests/.gitignore
index ecf434c22..88c873bc7 100644
--- a/tests/.gitignore
+++ b/tests/.gitignore
@@ -53,3 +53,4 @@
 !/run_test_octave.m
 !/swz/data.m
 !/test.m
+!/fs2000_ssfile_aux.m
diff --git a/tests/Makefile.am b/tests/Makefile.am
index 5cf4f45f8..bd3da9b3e 100644
--- a/tests/Makefile.am
+++ b/tests/Makefile.am
@@ -84,7 +84,8 @@ EXTRA_DIST = $(MODS) \
 	objectives \
 	ramst_initval_file_data.m \
 	homotopy/common.mod \
-	bvar_a_la_sims/bvar_sample.m
+	bvar_a_la_sims/bvar_sample.m \
+	fs2000_ssfile_aux.m
 
 TARGETS = check-matlab
 
diff --git a/tests/fs2000_ssfile.mod b/tests/fs2000_ssfile.mod
index eeb410e08..6e36877ba 100644
--- a/tests/fs2000_ssfile.mod
+++ b/tests/fs2000_ssfile.mod
@@ -53,8 +53,9 @@ steady_state_model;
   d  = l - mst + 1;
   y  = k^alp*n^(1-alp)*gst^alp;
   R  = mst/bet;
-  W  = l/n;
-  e = 1;
+
+  // Test function returning several arguments
+  [W, e] = fs2000_ssfile_aux(l, n);
   
   gp_obs = m/dA;
   gy_obs = dA;
diff --git a/tests/fs2000_ssfile_aux.m b/tests/fs2000_ssfile_aux.m
new file mode 100644
index 000000000..c16bfc96d
--- /dev/null
+++ b/tests/fs2000_ssfile_aux.m
@@ -0,0 +1,4 @@
+function [W, e] = fs2000_ssfile_aux(l, n)
+  W = l/n;
+  e = 1;
+end
-- 
GitLab