From a70b60604c441753aea474714c2b8a53892e4999 Mon Sep 17 00:00:00 2001
From: Houtan Bastani <houtanb@gmail.com>
Date: Wed, 3 Mar 2010 11:40:13 +0100
Subject: [PATCH] Modified external functions to enforce consistent number of
 function arguments within model_block

---
 DynareBison.yy            | 14 ++++-----
 ExprNode.cc               |  1 +
 ExternalFunctionsTable.cc | 60 ++++++++++++++++++++++-----------------
 ExternalFunctionsTable.hh | 14 ++++++---
 ModFile.cc                |  2 +-
 ParsingDriver.cc          | 27 ++++++++++++------
 ParsingDriver.hh          |  2 +-
 7 files changed, 72 insertions(+), 48 deletions(-)

diff --git a/DynareBison.yy b/DynareBison.yy
index 4eed5b49..2bdbf532 100644
--- a/DynareBison.yy
+++ b/DynareBison.yy
@@ -416,7 +416,7 @@ expression : '(' expression ')'
            | MIN '(' expression COMMA expression ')'
              { $$ = driver.add_min($3, $5); }
            | symbol { driver.push_external_function_arg_vector_onto_stack(); } '(' comma_expression ')'
-             { $$ = driver.add_model_var_or_external_function($1); }
+             { $$ = driver.add_model_var_or_external_function($1,false); }
            | NORMCDF '(' expression COMMA expression COMMA expression ')'
              { $$ = driver.add_normcdf($3, $5, $7); }
            | NORMCDF '(' expression ')'
@@ -562,17 +562,17 @@ hand_side : '(' hand_side ')'
           | SQRT '(' hand_side ')'
             { $$ = driver.add_sqrt($3); }
           | MAX '(' hand_side COMMA hand_side ')'
-             { $$ = driver.add_max($3, $5); }
+            { $$ = driver.add_max($3, $5); }
           | MIN '(' hand_side COMMA hand_side ')'
-             { $$ = driver.add_min($3, $5); }
+            { $$ = driver.add_min($3, $5); }
           | symbol { driver.push_external_function_arg_vector_onto_stack(); } '(' comma_hand_side ')'
-            { $$ = driver.add_model_var_or_external_function($1); }
+            { $$ = driver.add_model_var_or_external_function($1,true); }
           | NORMCDF '(' hand_side COMMA hand_side COMMA hand_side ')'
-             { $$ = driver.add_normcdf($3, $5, $7); }
+            { $$ = driver.add_normcdf($3, $5, $7); }
           | NORMCDF '(' hand_side ')'
-             { $$ = driver.add_normcdf($3); }
+            { $$ = driver.add_normcdf($3); }
           | STEADY_STATE '(' hand_side ')'
-             { $$ = driver.add_steady_state($3); }
+            { $$ = driver.add_steady_state($3); }
           ;
 
 comma_hand_side : hand_side
diff --git a/ExprNode.cc b/ExprNode.cc
index 38496423..6965ebd8 100644
--- a/ExprNode.cc
+++ b/ExprNode.cc
@@ -3206,6 +3206,7 @@ ExternalFunctionNode::prepareForDerivation()
 NodeID
 ExternalFunctionNode::computeDerivative(int deriv_id)
 {
+  assert(datatree.external_functions_table.getNargs(symb_id) > 0);
   vector<NodeID> dargs;
   for (vector<NodeID>::const_iterator it = arguments.begin(); it != arguments.end(); it++)
     dargs.push_back((*it)->getDerivative(deriv_id));
diff --git a/ExternalFunctionsTable.cc b/ExternalFunctionsTable.cc
index 83c03f60..2650a977 100644
--- a/ExternalFunctionsTable.cc
+++ b/ExternalFunctionsTable.cc
@@ -31,16 +31,12 @@ ExternalFunctionsTable::ExternalFunctionsTable()
 }
 
 void
-ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function_options &external_function_options_arg)
+ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function_options &external_function_options_arg, bool track_nargs)
 {
   assert(symb_id >= 0);
+  assert(external_function_options_arg.nargs > 0);
 
-  if (external_function_options_arg.nargs <= 0)
-    {
-      cerr << "ERROR: The number of arguments passed to an external function must be > 0." << endl;
-      exit(EXIT_FAILURE);
-    }
-
+  // Change options to be saved so the table is consistent
   external_function_options external_function_options_chng = external_function_options_arg;
   if (external_function_options_arg.firstDerivSymbID  == eExtFunSetButNoNameProvided)
     external_function_options_chng.firstDerivSymbID = symb_id;
@@ -48,6 +44,10 @@ ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function
   if (external_function_options_arg.secondDerivSymbID == eExtFunSetButNoNameProvided)
     external_function_options_chng.secondDerivSymbID = symb_id;
 
+  if (!track_nargs)
+    external_function_options_chng.nargs = eExtFunNotSet;
+
+  // Ensure 1st & 2nd deriv option consistency
   if (external_function_options_chng.secondDerivSymbID == symb_id &&
       external_function_options_chng.firstDerivSymbID  != symb_id)
     {
@@ -81,29 +81,37 @@ ExternalFunctionsTable::addExternalFunction(int symb_id, const external_function
       exit(EXIT_FAILURE);
     }
 
+  // Ensure that if we're overwriting something, we mean to do it
   if (exists(symb_id))
     {
-      if (external_function_options_arg.nargs != getNargs(symb_id))
-        {
-          cerr << "ERROR: The number of arguments passed to the external_function() statement do not "
-               << "match the number of arguments passed to a previous call or declaration of the top-level function."<< endl;
-          exit(EXIT_FAILURE);
-        }
+      bool ok_to_overwrite = false;
+      if (getNargs(symb_id) == eExtFunNotSet) // implies that the information stored about this function is not important
+        ok_to_overwrite = true;
 
-      if (external_function_options_chng.firstDerivSymbID != getFirstDerivSymbID(symb_id))
-        {
-          cerr << "ERROR: The first derivative function passed to the external_function() statement does not "
-               << "match the first derivative function passed to a previous call or declaration of the top-level function."<< endl;
-          exit(EXIT_FAILURE);
-        }
+      if (!ok_to_overwrite) // prevents multiple non-compatible calls to external_function(name=funcname)
+        {                   // e.g. e_f(name=a,nargs=1,fd,sd) and e_f(name=a,nargs=2,fd=b,sd=c) should cause an error
+         if (external_function_options_chng.nargs != getNargs(symb_id))
+            {
+              cerr << "ERROR: The number of arguments passed to the external_function() statement do not "
+                   << "match the number of arguments passed to a previous call or declaration of the top-level function."<< endl;
+              exit(EXIT_FAILURE);
+            }
 
-      if (external_function_options_chng.secondDerivSymbID != getSecondDerivSymbID(symb_id))
-        {
-          cerr << "ERROR: The second derivative function passed to the external_function() statement does not "
-               << "match the second derivative function passed to a previous call or declaration of the top-level function."<< endl;
-          exit(EXIT_FAILURE);
+          if (external_function_options_chng.firstDerivSymbID != getFirstDerivSymbID(symb_id))
+            {
+              cerr << "ERROR: The first derivative function passed to the external_function() statement does not "
+                   << "match the first derivative function passed to a previous call or declaration of the top-level function."<< endl;
+              exit(EXIT_FAILURE);
+            }
+
+          if (external_function_options_chng.secondDerivSymbID != getSecondDerivSymbID(symb_id))
+            {
+              cerr << "ERROR: The second derivative function passed to the external_function() statement does not "
+                   << "match the second derivative function passed to a previous call or declaration of the top-level function."<< endl;
+              exit(EXIT_FAILURE);
+            }
         }
     }
-  else
-    externalFunctionTable[symb_id] = external_function_options_chng;
+
+  externalFunctionTable[symb_id] = external_function_options_chng;
 }
diff --git a/ExternalFunctionsTable.hh b/ExternalFunctionsTable.hh
index e1d19cf9..85e97e40 100644
--- a/ExternalFunctionsTable.hh
+++ b/ExternalFunctionsTable.hh
@@ -65,7 +65,7 @@ private:
 public:
   ExternalFunctionsTable();
   //! Adds an external function to the table as well as its derivative functions
-  void addExternalFunction(int symb_id, const external_function_options &external_function_options_arg);
+  void addExternalFunction(int symb_id, const external_function_options &external_function_options_arg, bool track_nargs);
   //! See if the function exists in the External Functions Table
   inline bool exists(int symb_id) const;
   //! Get the number of arguments for a given external function
@@ -75,7 +75,7 @@ public:
   //! Get the symbol_id of the second derivative function
   inline int getSecondDerivSymbID(int symb_id) const throw (UnknownExternalFunctionSymbolIDException);
   //! Returns the total number of unique external functions declared or used in the .mod file
-  inline int get_total_number_of_unique_external_functions() const;
+  inline int get_total_number_of_unique_model_block_external_functions() const;
 };
 
 inline bool
@@ -113,9 +113,15 @@ ExternalFunctionsTable::getSecondDerivSymbID(int symb_id) const throw (UnknownEx
 }
 
 inline int
-ExternalFunctionsTable::get_total_number_of_unique_external_functions() const
+ExternalFunctionsTable::get_total_number_of_unique_model_block_external_functions() const
 {
-  return externalFunctionTable.size();
+  int number_of_unique_model_block_external_functions = 0;
+  for (external_function_table_type::const_iterator it = externalFunctionTable.begin();
+       it != externalFunctionTable.end(); it++)
+    if (it->second.nargs > 0)
+      number_of_unique_model_block_external_functions++;
+
+  return number_of_unique_model_block_external_functions;
 }
 
 #endif
diff --git a/ModFile.cc b/ModFile.cc
index 05ffae0b..409a1cd5 100644
--- a/ModFile.cc
+++ b/ModFile.cc
@@ -135,7 +135,7 @@ ModFile::checkPass()
       exit(EXIT_FAILURE);
     }
 
-  if ((use_dll || byte_code) && external_functions_table.get_total_number_of_unique_external_functions())
+  if ((use_dll || byte_code) && (external_functions_table.get_total_number_of_unique_model_block_external_functions() > 0))
     {
       cerr << "ERROR: In 'model' block, use of external functions is not compatible with 'use_dll' or 'bytecode'" << endl;
       exit(EXIT_FAILURE);
diff --git a/ParsingDriver.cc b/ParsingDriver.cc
index 0f890e22..d240abbf 100644
--- a/ParsingDriver.cc
+++ b/ParsingDriver.cc
@@ -1628,6 +1628,7 @@ void
 ParsingDriver::external_function_option(const string &name_option, string *opt)
 {
   external_function_option(name_option, *opt);
+  delete opt;
 }
 
 void
@@ -1680,7 +1681,7 @@ ParsingDriver::external_function()
       current_external_function_options.firstDerivSymbID  != eExtFunSetButNoNameProvided)
     error("If the second derivative is provided in the top-level function, the first derivative must also be provided in that function.");
 
-  mod_file->external_functions_table.addExternalFunction(current_external_function_id, current_external_function_options);
+  mod_file->external_functions_table.addExternalFunction(current_external_function_id, current_external_function_options, true);
   reset_current_external_function_options();
 }
 
@@ -1698,8 +1699,9 @@ ParsingDriver::add_external_function_arg(NodeID arg)
 }
 
 NodeID
-ParsingDriver::add_model_var_or_external_function(string *function_name)
+ParsingDriver::add_model_var_or_external_function(string *function_name, bool in_model_block)
 {
+  NodeID nid;
   if (mod_file->symbol_table.exists(*function_name))
     {
       if (mod_file->symbol_table.getType(*function_name) != eExternalFunction)
@@ -1734,7 +1736,7 @@ ParsingDriver::add_model_var_or_external_function(string *function_name)
               if ((double) model_var_arg != model_var_arg_dbl) //make 100% sure int cast didn't lose info
                 error("A model variable is being treated as if it were a function (i.e., takes an argument that is not an integer).");
 
-              NodeID nid = add_model_variable(mod_file->symbol_table.getID(*function_name), model_var_arg);
+              nid = add_model_variable(mod_file->symbol_table.getID(*function_name), model_var_arg);
               stack_external_function_args.pop();
               delete function_name;
               return nid;
@@ -1748,25 +1750,32 @@ ParsingDriver::add_model_var_or_external_function(string *function_name)
           int symb_id = mod_file->symbol_table.getID(*function_name);
           assert(mod_file->external_functions_table.exists(symb_id));
 
-          if ((int)(stack_external_function_args.top().size()) != mod_file->external_functions_table.getNargs(symb_id))
-            error("The number of arguments passed to " + *function_name +
-                  " does not match those of a previous call or declaration of this function.");
+          if (in_model_block)
+            if (mod_file->external_functions_table.getNargs(symb_id) == eExtFunNotSet)
+              error("Before using " + *function_name +
+                    "() in the model block, you must first declare it via the external_function() statement");
+            else if ((int)(stack_external_function_args.top().size()) != mod_file->external_functions_table.getNargs(symb_id))
+              error("The number of arguments passed to " + *function_name +
+                    "() does not match those of a previous call or declaration of this function.");
         }
     }
   else
     { //First time encountering this external function i.e., not previously declared or encountered
+      if (in_model_block)
+        error("To use an external function within the model block, you must first declare it via the external_function() statement.");
+
       declare_symbol(function_name, eExternalFunction, NULL);
       current_external_function_options.nargs = stack_external_function_args.top().size();
       mod_file->external_functions_table.addExternalFunction(mod_file->symbol_table.getID(*function_name),
-                                                             current_external_function_options);
+                                                             current_external_function_options, in_model_block);
       reset_current_external_function_options();
     }
 
   //By this point, we're sure that this function exists in the External Functions Table and is not a mod var
-  NodeID id = data_tree->AddExternalFunction(*function_name, stack_external_function_args.top());
+  nid = data_tree->AddExternalFunction(*function_name, stack_external_function_args.top());
   stack_external_function_args.pop();
   delete function_name;
-  return id;
+  return nid;
 }
 
 void
diff --git a/ParsingDriver.hh b/ParsingDriver.hh
index 55d26602..ce9f1bd2 100644
--- a/ParsingDriver.hh
+++ b/ParsingDriver.hh
@@ -484,7 +484,7 @@ public:
   //! Adds an external function argument
   void add_external_function_arg(NodeID arg);
   //! Adds an external function call node
-  NodeID add_model_var_or_external_function(string *function_name);
+  NodeID add_model_var_or_external_function(string *function_name, bool in_model_block);
   //! Adds a native statement
   void add_native(const char *s);
   //! Resets data_tree and model_tree pointers to default (i.e. mod_file->expressions_tree)
-- 
GitLab