Commit 51378821 authored by sebastien's avatar sebastien
Browse files

trunk preprocessor:

* enforce lag=0 for DataTree and StaticModel
* various minor and cosmetic changes


git-svn-id: https://www.dynare.org/svn/dynare/trunk@2596 ac1d8469-bf42-47a9-8791-bf33cf982152
parent 7afeae2f
......@@ -30,12 +30,15 @@ DataTree::DataTree(SymbolTable &symbol_table_arg, NumericalConstants &num_consta
{
Zero = AddNumConstant("0");
One = AddNumConstant("1");
Two = AddNumConstant("2");
MinusOne = AddUMinus(One);
NaN = AddNumConstant("NaN");
Infinity = AddNumConstant("Inf");
MinusInfinity = AddUMinus(Infinity);
Pi = AddNumConstant("3.141592653589793");
}
DataTree::~DataTree()
......@@ -57,7 +60,7 @@ DataTree::AddNumConstant(const string &value)
}
NodeID
DataTree::AddVariable(const string &name, int lag)
DataTree::AddVariableInternal(const string &name, int lag)
{
int symb_id = symbol_table.getID(name);
......@@ -68,6 +71,17 @@ DataTree::AddVariable(const string &name, int lag)
return new VariableNode(*this, symb_id, lag);
}
NodeID
DataTree::AddVariable(const string &name, int lag)
{
if (lag != 0)
{
cerr << "DataTree::AddVariable: a non-zero lag is forbidden here!" << endl;
exit(EXIT_FAILURE);
}
return AddVariableInternal(name, lag);
}
NodeID
DataTree::AddPlus(NodeID iArg1, NodeID iArg2)
{
......@@ -283,7 +297,7 @@ DataTree::AddTan(NodeID iArg1)
}
NodeID
DataTree::AddACos(NodeID iArg1)
DataTree::AddAcos(NodeID iArg1)
{
if (iArg1 != One)
return AddUnaryOp(oAcos, iArg1);
......@@ -292,7 +306,7 @@ DataTree::AddACos(NodeID iArg1)
}
NodeID
DataTree::AddASin(NodeID iArg1)
DataTree::AddAsin(NodeID iArg1)
{
if (iArg1 != Zero)
return AddUnaryOp(oAsin, iArg1);
......@@ -301,7 +315,7 @@ DataTree::AddASin(NodeID iArg1)
}
NodeID
DataTree::AddATan(NodeID iArg1)
DataTree::AddAtan(NodeID iArg1)
{
if (iArg1 != Zero)
return AddUnaryOp(oAtan, iArg1);
......@@ -310,7 +324,7 @@ DataTree::AddATan(NodeID iArg1)
}
NodeID
DataTree::AddCosH(NodeID iArg1)
DataTree::AddCosh(NodeID iArg1)
{
if (iArg1 != Zero)
return AddUnaryOp(oCosh, iArg1);
......@@ -319,7 +333,7 @@ DataTree::AddCosH(NodeID iArg1)
}
NodeID
DataTree::AddSinH(NodeID iArg1)
DataTree::AddSinh(NodeID iArg1)
{
if (iArg1 != Zero)
return AddUnaryOp(oSinh, iArg1);
......@@ -328,7 +342,7 @@ DataTree::AddSinH(NodeID iArg1)
}
NodeID
DataTree::AddTanH(NodeID iArg1)
DataTree::AddTanh(NodeID iArg1)
{
if (iArg1 != Zero)
return AddUnaryOp(oTanh, iArg1);
......@@ -337,7 +351,7 @@ DataTree::AddTanH(NodeID iArg1)
}
NodeID
DataTree::AddACosH(NodeID iArg1)
DataTree::AddAcosh(NodeID iArg1)
{
if (iArg1 != One)
return AddUnaryOp(oAcosh, iArg1);
......@@ -346,7 +360,7 @@ DataTree::AddACosH(NodeID iArg1)
}
NodeID
DataTree::AddASinH(NodeID iArg1)
DataTree::AddAsinh(NodeID iArg1)
{
if (iArg1 != Zero)
return AddUnaryOp(oAsinh, iArg1);
......@@ -355,7 +369,7 @@ DataTree::AddASinH(NodeID iArg1)
}
NodeID
DataTree::AddATanH(NodeID iArg1)
DataTree::AddAtanh(NodeID iArg1)
{
if (iArg1 != Zero)
return AddUnaryOp(oAtanh, iArg1);
......@@ -364,7 +378,7 @@ DataTree::AddATanH(NodeID iArg1)
}
NodeID
DataTree::AddSqRt(NodeID iArg1)
DataTree::AddSqrt(NodeID iArg1)
{
if (iArg1 != Zero)
return AddUnaryOp(oSqrt, iArg1);
......@@ -373,7 +387,7 @@ DataTree::AddSqRt(NodeID iArg1)
}
NodeID
DataTree::AddMaX(NodeID iArg1, NodeID iArg2)
DataTree::AddMax(NodeID iArg1, NodeID iArg2)
{
return AddBinaryOp(iArg1, oMax, iArg2);
}
......@@ -397,14 +411,20 @@ DataTree::AddEqual(NodeID iArg1, NodeID iArg2)
}
void
DataTree::AddLocalParameter(const string &name, NodeID value) throw (LocalParameterException)
DataTree::AddLocalVariable(const string &name, NodeID value) throw (LocalVariableException)
{
int id = symbol_table.getID(name);
if (symbol_table.getType(id) != eModelLocalVariable)
{
cerr << "Symbol " << name << " is not a model local variable!" << endl;
exit(EXIT_FAILURE);
}
// Throw an exception if symbol already declared
map<int, NodeID>::iterator it = local_variables_table.find(id);
if (it != local_variables_table.end())
throw LocalParameterException(name);
throw LocalVariableException(name);
local_variables_table[id] = value;
}
......@@ -412,14 +432,14 @@ DataTree::AddLocalParameter(const string &name, NodeID value) throw (LocalParame
NodeID
DataTree::AddUnknownFunction(const string &function_name, const vector<NodeID> &arguments)
{
if (symbol_table.getType(function_name) != eUnknownFunction)
int id = symbol_table.getID(function_name);
if (symbol_table.getType(id) != eUnknownFunction)
{
cerr << "Symbol " << function_name << " is not a function name!";
cerr << "Symbol " << function_name << " is not a function name!" << endl;
exit(EXIT_FAILURE);
}
int id = symbol_table.getID(function_name);
return new UnknownFunctionNode(*this, id, arguments);
}
......
......@@ -50,13 +50,6 @@ protected:
//! Reference to numerical constants table
NumericalConstants &num_constants;
typedef list<NodeID> node_list_type;
//! The list of nodes
node_list_type node_list;
//! A counter for filling ExprNode's idx field
int node_counter;
typedef map<int, NodeID> num_const_node_map_type;
num_const_node_map_type num_const_node_map;
//! Pair (symbol_id, lag) used as key
......@@ -69,31 +62,45 @@ protected:
typedef map<pair<pair<pair<NodeID, NodeID>,NodeID>, int>, NodeID> trinary_op_node_map_type;
trinary_op_node_map_type trinary_op_node_map;
//! Stores local variables value (maps symbol ID to corresponding node)
map<int, NodeID> local_variables_table;
//! Internal implementation of AddVariable(), without the check on the lag
NodeID AddVariableInternal(const string &name, int lag);
private:
typedef list<NodeID> node_list_type;
//! The list of nodes
node_list_type node_list;
//! A counter for filling ExprNode's idx field
int node_counter;
inline NodeID AddPossiblyNegativeConstant(double val);
inline NodeID AddUnaryOp(UnaryOpcode op_code, NodeID arg);
inline NodeID AddBinaryOp(NodeID arg1, BinaryOpcode op_code, NodeID arg2);
inline NodeID AddTrinaryOp(NodeID arg1, TrinaryOpcode op_code, NodeID arg2, NodeID arg3);
//! Stores local variables value (maps symbol ID to corresponding node)
map<int, NodeID> local_variables_table;
public:
DataTree(SymbolTable &symbol_table_arg, NumericalConstants &num_constants_arg);
virtual ~DataTree();
//! The variable table
VariableTable variable_table;
NodeID Zero, One, MinusOne, NaN, Infinity, MinusInfinity;
//! Some predefined constants
NodeID Zero, One, Two, MinusOne, NaN, Infinity, MinusInfinity, Pi;
//! Raised when a local parameter is declared twice
class LocalParameterException
class LocalVariableException
{
public:
string name;
LocalParameterException(const string &name_arg) : name(name_arg) {}
LocalVariableException(const string &name_arg) : name(name_arg) {}
};
//! Adds a numerical constant
NodeID AddNumConstant(const string &value);
NodeID AddVariable(const string &name, int lag = 0);
//! Adds a variable
/*! The default implementation of the method refuses any lag != 0 */
virtual NodeID AddVariable(const string &name, int lag = 0);
//! Adds "arg1+arg2" to model tree
NodeID AddPlus(NodeID iArg1, NodeID iArg2);
//! Adds "arg1-arg2" to model tree
......@@ -131,34 +138,35 @@ public:
//! Adds "tan(arg)" to model tree
NodeID AddTan(NodeID iArg1);
//! Adds "acos(arg)" to model tree
NodeID AddACos(NodeID iArg1);
NodeID AddAcos(NodeID iArg1);
//! Adds "asin(arg)" to model tree
NodeID AddASin(NodeID iArg1);
NodeID AddAsin(NodeID iArg1);
//! Adds "atan(arg)" to model tree
NodeID AddATan(NodeID iArg1);
NodeID AddAtan(NodeID iArg1);
//! Adds "cosh(arg)" to model tree
NodeID AddCosH(NodeID iArg1);
NodeID AddCosh(NodeID iArg1);
//! Adds "sinh(arg)" to model tree
NodeID AddSinH(NodeID iArg1);
NodeID AddSinh(NodeID iArg1);
//! Adds "tanh(arg)" to model tree
NodeID AddTanH(NodeID iArg1);
NodeID AddTanh(NodeID iArg1);
//! Adds "acosh(arg)" to model tree
NodeID AddACosH(NodeID iArg1);
NodeID AddAcosh(NodeID iArg1);
//! Adds "asinh(arg)" to model tree
NodeID AddASinH(NodeID iArg1);
NodeID AddAsinh(NodeID iArg1);
//! Adds "atanh(args)" to model tree
NodeID AddATanH(NodeID iArg1);
NodeID AddAtanh(NodeID iArg1);
//! Adds "sqrt(arg)" to model tree
NodeID AddSqRt(NodeID iArg1);
NodeID AddSqrt(NodeID iArg1);
//! Adds "max(arg1,arg2)" to model tree
NodeID AddMaX(NodeID iArg1, NodeID iArg2);
NodeID AddMax(NodeID iArg1, NodeID iArg2);
//! Adds "min(arg1,arg2)" to model tree
NodeID AddMin(NodeID iArg1, NodeID iArg2);
//! Adds "normcdf(arg1,arg2,arg3)" to model tree
NodeID AddNormcdf(NodeID iArg1, NodeID iArg2, NodeID iArg3);
//! Adds "arg1=arg2" to model tree
NodeID AddEqual(NodeID iArg1, NodeID iArg2);
void AddLocalParameter(const string &name, NodeID value) throw (LocalParameterException);
//! Adds a model local variable with its value
void AddLocalVariable(const string &name, NodeID value) throw (LocalVariableException);
//! Adds an unknown function node
/*! \todo Use a map to share identical nodes */
NodeID AddUnknownFunction(const string &function_name, const vector<NodeID> &arguments);
......
......@@ -44,6 +44,12 @@ DynamicModel::DynamicModel(SymbolTable &symbol_table_arg,
{
}
NodeID
DynamicModel::AddVariable(const string &name, int lag)
{
return AddVariableInternal(name, lag);
}
void
DynamicModel::compileDerivative(ofstream &code_file, int eq, int symb_id, int lag, ExprNodeOutputType output_type, map_idx_type &map_idx) const
{
......@@ -2240,7 +2246,7 @@ DynamicModel::toStatic(StaticModel &static_model) const
// Convert model local variables (need to be done first)
for (map<int, NodeID>::const_iterator it = local_variables_table.begin();
it != local_variables_table.end(); it++)
static_model.AddLocalParameter(symbol_table.getName(it->first), it->second->toStatic(static_model));
static_model.AddLocalVariable(symbol_table.getName(it->first), it->second->toStatic(static_model));
// Convert equations
for (vector<BinaryOpNode *>::const_iterator it = equations.begin();
......
......@@ -61,6 +61,9 @@ private:
public:
DynamicModel(SymbolTable &symbol_table_arg, NumericalConstants &num_constants);
//! Adds a variable node
/*! This implementation allows for non-zero lag */
virtual NodeID AddVariable(const string &name, int lag = 0);
//! Absolute value under which a number is considered to be zero
double cutoff;
//! The weight of the Markowitz criteria to determine the pivot in the linear solver (simul_NG1 from simulate.cc)
......
......@@ -247,7 +247,7 @@ VariableNode::computeDerivative(int varID)
cerr << "Impossible case!" << endl;
exit(EXIT_FAILURE);
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -577,20 +577,20 @@ UnaryOpNode::computeDerivative(int varID)
t12 = datatree.AddPlus(datatree.One, t11);
return datatree.AddDivide(darg, t12);
case oCosh:
t11 = datatree.AddSinH(arg);
t11 = datatree.AddSinh(arg);
return datatree.AddTimes(darg, t11);
case oSinh:
t11 = datatree.AddCosH(arg);
t11 = datatree.AddCosh(arg);
return datatree.AddTimes(darg, t11);
case oTanh:
t11 = datatree.AddTimes(this, this);
t12 = datatree.AddMinus(datatree.One, t11);
return datatree.AddTimes(darg, t12);
case oAcosh:
t11 = datatree.AddSinH(this);
t11 = datatree.AddSinh(this);
return datatree.AddDivide(darg, t11);
case oAsinh:
t11 = datatree.AddCosH(this);
t11 = datatree.AddCosh(this);
return datatree.AddDivide(darg, t11);
case oAtanh:
t11 = datatree.AddTimes(arg, arg);
......@@ -600,7 +600,7 @@ UnaryOpNode::computeDerivative(int varID)
t11 = datatree.AddPlus(this, this);
return datatree.AddDivide(darg, t11);
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -685,7 +685,7 @@ UnaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) c
case oSqrt:
return cost + 90;
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -889,8 +889,8 @@ UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) throw (EvalException)
case oSqrt:
return(sqrt(v));
}
// Impossible
throw EvalException();
// Suppress GCC warning
exit(EXIT_FAILURE);
}
double
......@@ -951,26 +951,28 @@ UnaryOpNode::toStatic(DataTree &static_datatree) const
case oTan:
return static_datatree.AddTan(sarg);
case oAcos:
return static_datatree.AddACos(sarg);
return static_datatree.AddAcos(sarg);
case oAsin:
return static_datatree.AddASin(sarg);
return static_datatree.AddAsin(sarg);
case oAtan:
return static_datatree.AddATan(sarg);
return static_datatree.AddAtan(sarg);
case oCosh:
return static_datatree.AddCosH(sarg);
return static_datatree.AddCosh(sarg);
case oSinh:
return static_datatree.AddSinH(sarg);
return static_datatree.AddSinh(sarg);
case oTanh:
return static_datatree.AddTanH(sarg);
return static_datatree.AddTanh(sarg);
case oAcosh:
return static_datatree.AddACosH(sarg);
return static_datatree.AddAcosh(sarg);
case oAsinh:
return static_datatree.AddASinH(sarg);
return static_datatree.AddAsinh(sarg);
case oAtanh:
return static_datatree.AddATanH(sarg);
return static_datatree.AddAtanh(sarg);
case oSqrt:
return static_datatree.AddSqRt(sarg);
return static_datatree.AddSqrt(sarg);
}
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1060,7 +1062,7 @@ BinaryOpNode::computeDerivative(int varID)
case oEqual:
return datatree.AddMinus(darg1, darg2);
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1100,7 +1102,7 @@ BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_t
case oMax:
return 100;
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1165,7 +1167,7 @@ BinaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab)
case oEqual:
return cost;
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1263,7 +1265,7 @@ BinaryOpNode::eval_opcode(double v1, BinaryOpcode op_code, double v2) throw (Eva
case oEqual:
throw EvalException();
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1475,7 +1477,7 @@ BinaryOpNode::toStatic(DataTree &static_datatree) const
case oEqual:
return static_datatree.AddEqual(sarg1, sarg2);
case oMax:
return static_datatree.AddMaX(sarg1, sarg2);
return static_datatree.AddMax(sarg1, sarg2);
case oMin:
return static_datatree.AddMin(sarg1, sarg2);
case oLess:
......@@ -1491,6 +1493,8 @@ BinaryOpNode::toStatic(DataTree &static_datatree) const
case oDifferent:
return static_datatree.AddDifferent(sarg1, sarg2);
}
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1533,12 +1537,8 @@ TrinaryOpNode::computeDerivative(int varID)
case oNormcdf:
// normal pdf is inlined in the tree
NodeID y;
t11 = datatree.AddNumConstant("2");
t12 = datatree.AddNumConstant("3.141592653589793");
// 2 * pi
t13 = datatree.AddTimes(t11,t12);
// sqrt(2*pi)
t14 = datatree.AddSqRt(t13);
t14 = datatree.AddSqrt(datatree.AddTimes(datatree.Two, datatree.Pi));
// x - mu
t12 = datatree.AddMinus(arg1,arg2);
// y = (x-mu)/sigma
......@@ -1570,7 +1570,7 @@ TrinaryOpNode::computeDerivative(int varID)
// where t13 is the derivative of a standardized normal
return datatree.AddTimes(t11, t15);
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1587,7 +1587,7 @@ TrinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_
case oNormcdf:
return 100;
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1616,7 +1616,7 @@ TrinaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab)
case oNormcdf:
return cost+1000;
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1685,7 +1685,7 @@ TrinaryOpNode::eval_opcode(double v1, TrinaryOpcode op_code, double v2, double v
cerr << "NORMCDF: eval not implemented" << endl;
exit(EXIT_FAILURE);
}
cerr << "Impossible case!" << endl;
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......@@ -1797,6 +1797,8 @@ TrinaryOpNode::toStatic(DataTree &static_datatree) const
case oNormcdf:
return static_datatree.AddNormcdf(sarg1, sarg2, sarg3);
}
// Suppress GCC warning
exit(EXIT_FAILURE);
}
......
......@@ -1078,7 +1078,7 @@ ParsingDriver::declare_and_init_model_local_variable(string *name, NodeID rhs)
error("Local model variable " + *name + " declared twice.");
}
model_tree->AddLocalParameter(*name, rhs);
model_tree->AddLocalVariable(*name, rhs);
delete name;
}
......@@ -1221,67 +1221,67 @@ ParsingDriver::add_tan(NodeID arg1)
NodeID
ParsingDriver::add_acos(NodeID arg1)
{
return data_tree->AddACos(arg1);
return data_tree->AddAcos(arg1);
}
NodeID
ParsingDriver::add_asin(NodeID arg1)
{
return data_tree->AddASin(arg1);
return data_tree->AddAsin(arg1);
}
NodeID
ParsingDriver::add_atan(NodeID arg1)
{
return data_tree->AddATan(arg1);
return data_tree->AddAtan(arg1);
}
NodeID
ParsingDriver::add_cosh(NodeID arg1)
{
return data_tree->AddCosH(arg1);
return data_tree->AddCosh(arg1);
}
NodeID
ParsingDriver::add_sinh(NodeID arg1)
{
return data_tree->AddSinH(arg1);
return data_tree->AddSinh(arg1);
}
NodeID
ParsingDriver::add_tanh(NodeID arg1)
{
return data_tree->AddTanH(arg1);
return data_tree->AddTanh(arg1);
}
NodeID
ParsingDriver::add_acosh(NodeID arg1)
{
return data_tree->AddACosH(arg1);
return data_tree->AddAcosh(arg1);
}
NodeID
ParsingDriver::add_asinh(NodeID arg1)
{
return data_tree->AddASinH(arg1);
return data_tree->AddAsinh(arg1);
}
NodeID
ParsingDriver::add_atanh(NodeID arg1)
{
return data_tree->AddATanH(arg1);
return data_tree->AddAtanh(arg1);
}
NodeID
ParsingDriver::add_sqrt(NodeID arg1)
{
return data_tree->AddSqRt(arg1);
return data_tree->AddSqrt(arg1);
}
NodeID
ParsingDriver::add_max(NodeID arg1, NodeID arg2)
{
return data_tree->AddMaX(arg1,arg2);
return data_tree->AddMax(arg1,arg2);
}
NodeID
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment