/*
 * Copyright (C) 2003-2009 Dynare Team
 *
 * This file is part of Dynare.
 *
 * Dynare is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Dynare is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Dynare.  If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef _DATATREE_HH
#define _DATATREE_HH

using namespace std;

#include <string>
#include <map>
#include <list>
#include <sstream>
#include <iomanip>

#include "SymbolTable.hh"
#include "NumericalConstants.hh"
#include "VariableTable.hh"
#include "ExprNode.hh"

#define CONSTANTS_PRECISION 16

class DataTree
{
  friend class ExprNode;
  friend class NumConstNode;
  friend class VariableNode;
  friend class UnaryOpNode;
  friend class BinaryOpNode;
  friend class TrinaryOpNode;
  friend class UnknownFunctionNode;
protected:
  //! A reference to the symbol table
  SymbolTable &symbol_table;
  //! Reference to numerical constants table
  NumericalConstants &num_constants;

  typedef map<int, NodeID> num_const_node_map_type;
  num_const_node_map_type num_const_node_map;
sebastien's avatar
sebastien committed
55
56
  //! Pair (symbol_id, lag) used as key
  typedef map<pair<int, int>, NodeID> variable_node_map_type;
57
58
59
60
61
62
63
64
  variable_node_map_type variable_node_map;
  typedef map<pair<NodeID, int>, NodeID> unary_op_node_map_type;
  unary_op_node_map_type unary_op_node_map;
  typedef map<pair<pair<NodeID, NodeID>, int>, NodeID> binary_op_node_map_type;
  binary_op_node_map_type binary_op_node_map;
  typedef map<pair<pair<pair<NodeID, NodeID>,NodeID>, int>, NodeID> trinary_op_node_map_type;
  trinary_op_node_map_type trinary_op_node_map;

sebastien's avatar
sebastien committed
65
66
67
68
69
70
71
72
73
74
75
76
77
  //! Stores local variables value (maps symbol ID to corresponding node)
  map<int, NodeID> local_variables_table;

  //! Internal implementation of AddVariable(), without the check on the lag
  NodeID AddVariableInternal(const string &name, int lag);

private:
  typedef list<NodeID> node_list_type;
  //! The list of nodes
  node_list_type node_list;
  //! A counter for filling ExprNode's idx field
  int node_counter;

78
79
80
81
  inline NodeID AddPossiblyNegativeConstant(double val);
  inline NodeID AddUnaryOp(UnaryOpcode op_code, NodeID arg);
  inline NodeID AddBinaryOp(NodeID arg1, BinaryOpcode op_code, NodeID arg2);
  inline NodeID AddTrinaryOp(NodeID arg1, TrinaryOpcode op_code, NodeID arg2, NodeID arg3);
sebastien's avatar
sebastien committed
82

83
84
85
86
87
public:
  DataTree(SymbolTable &symbol_table_arg, NumericalConstants &num_constants_arg);
  virtual ~DataTree();
  //! The variable table
  VariableTable variable_table;
sebastien's avatar
sebastien committed
88
89
  //! Some predefined constants
  NodeID Zero, One, Two, MinusOne, NaN, Infinity, MinusInfinity, Pi;
90
91

  //! Raised when a local parameter is declared twice
sebastien's avatar
sebastien committed
92
  class LocalVariableException
93
94
95
  {
  public:
    string name;
sebastien's avatar
sebastien committed
96
    LocalVariableException(const string &name_arg) : name(name_arg) {}
97
98
  };

sebastien's avatar
sebastien committed
99
  //! Adds a numerical constant
100
  NodeID AddNumConstant(const string &value);
sebastien's avatar
sebastien committed
101
102
103
  //! Adds a variable
  /*! The default implementation of the method refuses any lag != 0 */
  virtual NodeID AddVariable(const string &name, int lag = 0);
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
  //! Adds "arg1+arg2" to model tree
  NodeID AddPlus(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1-arg2" to model tree
  NodeID AddMinus(NodeID iArg1, NodeID iArg2);
  //! Adds "-arg" to model tree
  NodeID AddUMinus(NodeID iArg1);
  //! Adds "arg1*arg2" to model tree
  NodeID AddTimes(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1/arg2" to model tree
  NodeID AddDivide(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1<arg2" to model tree
  NodeID AddLess(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1>arg2" to model tree
  NodeID AddGreater(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1<=arg2" to model tree
  NodeID AddLessEqual(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1>=arg2" to model tree
  NodeID AddGreaterEqual(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1==arg2" to model tree
  NodeID AddEqualEqual(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1!=arg2" to model tree
  NodeID AddDifferent(NodeID iArg1, NodeID iArg2);
  //! Adds "arg1^arg2" to model tree
  NodeID AddPower(NodeID iArg1, NodeID iArg2);
  //! Adds "exp(arg)" to model tree
  NodeID AddExp(NodeID iArg1);
  //! Adds "log(arg)" to model tree
  NodeID AddLog(NodeID iArg1);
  //! Adds "log10(arg)" to model tree
  NodeID AddLog10(NodeID iArg1);
  //! Adds "cos(arg)" to model tree
  NodeID AddCos(NodeID iArg1);
  //! Adds "sin(arg)" to model tree
  NodeID AddSin(NodeID iArg1);
  //! Adds "tan(arg)" to model tree
  NodeID AddTan(NodeID iArg1);
  //! Adds "acos(arg)" to model tree
sebastien's avatar
sebastien committed
141
  NodeID AddAcos(NodeID iArg1);
142
  //! Adds "asin(arg)" to model tree
sebastien's avatar
sebastien committed
143
  NodeID AddAsin(NodeID iArg1);
144
  //! Adds "atan(arg)" to model tree
sebastien's avatar
sebastien committed
145
  NodeID AddAtan(NodeID iArg1);
146
  //! Adds "cosh(arg)" to model tree
sebastien's avatar
sebastien committed
147
  NodeID AddCosh(NodeID iArg1);
148
  //! Adds "sinh(arg)" to model tree
sebastien's avatar
sebastien committed
149
  NodeID AddSinh(NodeID iArg1);
150
  //! Adds "tanh(arg)" to model tree
sebastien's avatar
sebastien committed
151
  NodeID AddTanh(NodeID iArg1);
152
  //! Adds "acosh(arg)" to model tree
sebastien's avatar
sebastien committed
153
  NodeID AddAcosh(NodeID iArg1);
154
  //! Adds "asinh(arg)" to model tree
sebastien's avatar
sebastien committed
155
  NodeID AddAsinh(NodeID iArg1);
156
  //! Adds "atanh(args)" to model tree
sebastien's avatar
sebastien committed
157
  NodeID AddAtanh(NodeID iArg1);
158
  //! Adds "sqrt(arg)" to model tree
sebastien's avatar
sebastien committed
159
  NodeID AddSqrt(NodeID iArg1);
160
  //! Adds "max(arg1,arg2)" to model tree
sebastien's avatar
sebastien committed
161
  NodeID AddMax(NodeID iArg1, NodeID iArg2);
162
163
164
165
166
167
  //! Adds "min(arg1,arg2)" to model tree
  NodeID AddMin(NodeID iArg1, NodeID iArg2);
  //! Adds "normcdf(arg1,arg2,arg3)" to model tree
  NodeID AddNormcdf(NodeID iArg1, NodeID iArg2, NodeID iArg3);
  //! Adds "arg1=arg2" to model tree
  NodeID AddEqual(NodeID iArg1, NodeID iArg2);
sebastien's avatar
sebastien committed
168
169
  //! Adds a model local variable with its value
  void AddLocalVariable(const string &name, NodeID value) throw (LocalVariableException);
170
171
172
  //! Adds an unknown function node
  /*! \todo Use a map to share identical nodes */
  NodeID AddUnknownFunction(const string &function_name, const vector<NodeID> &arguments);
sebastien's avatar
sebastien committed
173
174
  //! Fill eval context with values of local variables
  void fillEvalContext(eval_context_type &eval_context) const;
175
176
  //! Checks if a given symbol is used somewhere in the data tree
  bool isSymbolUsed(int symb_id) const;
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
};

inline NodeID
DataTree::AddPossiblyNegativeConstant(double v)
{
  // A value that compares below zero is emitted as AddUMinus() applied to the
  // constant node for its magnitude; any other value (including -0.0, which
  // prints as "-0") is passed straight to AddNumConstant().
  const bool negative = (v < 0);
  const double magnitude = negative ? -v : v;

  // Textual form with CONSTANTS_PRECISION significant digits.
  // NOTE(review): 16 digits does not round-trip every double (17 are needed).
  ostringstream literal;
  literal << setprecision(CONSTANTS_PRECISION) << magnitude;
  NodeID constant = AddNumConstant(literal.str());

  return negative ? AddUMinus(constant) : constant;
}

inline NodeID
DataTree::AddUnaryOp(UnaryOpcode op_code, NodeID arg)
{
  // Reuse a node already built for the same (argument, opcode) key
  unary_op_node_map_type::iterator existing = unary_op_node_map.find(make_pair(arg, op_code));
  if (existing != unary_op_node_map.end())
    return existing->second;

  // A unary minus over a numerical constant is exactly how
  // AddPossiblyNegativeConstant() represents a negative value; folding it
  // would call back into AddPossiblyNegativeConstant() -> AddUMinus() forever.
  const bool negating_a_constant = (op_code == oUminus
                                    && dynamic_cast<NumConstNode *>(arg) != NULL);
  if (!negating_a_constant)
    {
      try
        {
          // Evaluate with an empty context: succeeds only for constant operands
          return AddPossiblyNegativeConstant(UnaryOpNode::eval_opcode(op_code, arg->eval(eval_context_type())));
        }
      catch (ExprNode::EvalException &)
        {
          // Not reducible to a constant: build an operator node below
        }
    }

  return new UnaryOpNode(*this, op_code, arg);
}

inline NodeID
DataTree::AddBinaryOp(NodeID arg1, BinaryOpcode op_code, NodeID arg2)
{
  // Reuse a node already built for the same ((arg1, arg2), opcode) key
  binary_op_node_map_type::iterator existing
    = binary_op_node_map.find(make_pair(make_pair(arg1, arg2), op_code));
  if (existing != binary_op_node_map.end())
    return existing->second;

  // Constant folding: evaluation with no variable values only succeeds when
  // both operands are constant expressions
  try
    {
      const eval_context_type no_values = eval_context_type();
      const double lhs = arg1->eval(no_values);
      const double rhs = arg2->eval(no_values);
      return AddPossiblyNegativeConstant(BinaryOpNode::eval_opcode(lhs, op_code, rhs));
    }
  catch (ExprNode::EvalException &)
    {
      // Not reducible to a constant: build an operator node below
    }

  return new BinaryOpNode(*this, arg1, op_code, arg2);
}

inline NodeID
DataTree::AddTrinaryOp(NodeID arg1, TrinaryOpcode op_code, NodeID arg2, NodeID arg3)
{
  // Reuse a node already built for the same (((arg1, arg2), arg3), opcode) key
  trinary_op_node_map_type::iterator existing
    = trinary_op_node_map.find(make_pair(make_pair(make_pair(arg1, arg2), arg3), op_code));
  if (existing != trinary_op_node_map.end())
    return existing->second;

  // Constant folding: evaluation with no variable values only succeeds when
  // all three operands are constant expressions
  try
    {
      const eval_context_type no_values = eval_context_type();
      const double first = arg1->eval(no_values);
      const double second = arg2->eval(no_values);
      const double third = arg3->eval(no_values);
      return AddPossiblyNegativeConstant(TrinaryOpNode::eval_opcode(first, op_code, second, third));
    }
  catch (ExprNode::EvalException &)
    {
      // Not reducible to a constant: build an operator node below
    }

  return new TrinaryOpNode(*this, arg1, op_code, arg2, arg3);
}

#endif