/*
 * Copyright (C) 2007-2009 Dynare Team
 *
 * This file is part of Dynare.
 *
 * Dynare is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Dynare is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Dynare.  If not, see <http://www.gnu.org/licenses/>.
 */

#ifndef _EXPR_NODE_HH
#define _EXPR_NODE_HH

#include <map>
#include <ostream>
#include <set>
#include <utility>
#include <vector>

// Placed after the standard headers so that namespace std is already declared when it
// is nominated; kept before the project headers, which may rely on it.
using namespace std;

#include "SymbolTable.hh"
#include "CodeInterpreter.hh"

class DataTree;

class VariableNode;
class BinaryOpNode;

//! Handle used throughout this header to refer to an expression node
/*! NOTE(review): nodes are referenced through raw pointers; presumably they are owned by
    the DataTree that created them (see ExprNode::datatree) — confirm before relying on it */
typedef class ExprNode *NodeID;

struct Model_Block;

struct ExprNodeLess;

//! Type for set of temporary terms
/*! They are ordered by index number thanks to ExprNodeLess */
typedef set<NodeID, ExprNodeLess> temporary_terms_type;

//! Integer-to-integer index map
/*! Passed to compile() and to the block-decomposed computeTemporaryTerms().
    NOTE(review): presumably maps node indices to storage slots — confirm in the implementation */
typedef map<int,int> map_idx_type;

//! Type for evaluation contexts
/*! The key is a symbol id. Lags are assumed to be null */
typedef map<int, double> eval_context_type;

//! Possible types of output when writing ExprNode(s)
/*! Each output type must also be classified by the IS_MATLAB / IS_C / IS_LATEX macros */
enum ExprNodeOutputType
  {
    oMatlabStaticModel,                          //!< Matlab code, static model declarations
    oMatlabDynamicModel,                         //!< Matlab code, dynamic model declarations
    oMatlabStaticModelSparse,                    //!< Matlab code, static block decomposed model declarations
    oMatlabDynamicModelSparse,                   //!< Matlab code, dynamic block decomposed model declarations
    oCDynamicModel,                              //!< C code, dynamic model declarations
    oMatlabOutsideModel,                         //!< Matlab code, outside model block (for example in initval)
    oLatexStaticModel,                           //!< LaTeX code, static model declarations
    oLatexDynamicModel,                          //!< LaTeX code, dynamic model declarations
    oLatexDynamicSteadyStateOperator,            //!< LaTeX code, dynamic model steady state declarations
    oMatlabDynamicSteadyStateOperator,           //!< Matlab code, dynamic model steady state declarations
    oMatlabDynamicModelSparseSteadyStateOperator //!< Matlab code, dynamic block decomposed model steady state declarations
  };

//! True if output_type produces Matlab code
#define IS_MATLAB(output_type) ((output_type) == oMatlabStaticModel     \
                                || (output_type) == oMatlabDynamicModel \
                                || (output_type) == oMatlabOutsideModel \
                                || (output_type) == oMatlabStaticModelSparse \
                                || (output_type) == oMatlabDynamicModelSparse \
                                || (output_type) == oMatlabDynamicSteadyStateOperator \
                                || (output_type) == oMatlabDynamicModelSparseSteadyStateOperator)

//! True if output_type produces C code
#define IS_C(output_type) ((output_type) == oCDynamicModel)

//! True if output_type produces LaTeX code
#define IS_LATEX(output_type) ((output_type) == oLatexStaticModel       \
                               || (output_type) == oLatexDynamicModel \
                               || (output_type) == oLatexDynamicSteadyStateOperator)

/* Equal to 1 for Matlab language, or to 0 for C language. Not meaningful for LaTeX
   (it evaluates to 0 there).
   In Matlab, array indexes begin at 1, while they begin at 0 in C */
#define ARRAY_SUBSCRIPT_OFFSET(output_type) (static_cast<int>(IS_MATLAB(output_type)))

// Left and right array subscript delimiters: '(' and ')' for Matlab, '[' and ']' for C
#define LEFT_ARRAY_SUBSCRIPT(output_type) (IS_MATLAB(output_type) ? '(' : '[')
#define RIGHT_ARRAY_SUBSCRIPT(output_type) (IS_MATLAB(output_type) ? ')' : ']')

// Left and right parentheses
#define LEFT_PAR(output_type) (IS_LATEX(output_type) ? "\\left(" : "(")
#define RIGHT_PAR(output_type) (IS_LATEX(output_type) ? "\\right)" : ")")

// Computing cost above which a node can be declared a temporary term
#define MIN_COST_MATLAB (40*90)
#define MIN_COST_C (40*4)
#define MIN_COST(is_matlab) ((is_matlab) ? MIN_COST_MATLAB : MIN_COST_C)

//! Base class for expression nodes
class ExprNode
{
  friend class DataTree;
sebastien's avatar
sebastien committed
104
  friend class DynamicModel;
105
  friend class StaticDllModel;
106 107 108 109 110 111 112 113
  friend class ExprNodeLess;
  friend class NumConstNode;
  friend class VariableNode;
  friend class UnaryOpNode;
  friend class BinaryOpNode;
  friend class TrinaryOpNode;

private:
114
  //! Computes derivative w.r. to a derivation ID (but doesn't store it in derivatives map)
115
  /*! You shoud use getDerivative() to get the benefit of symbolic a priori and of caching */
116
  virtual NodeID computeDerivative(int deriv_id) = 0;
117 118 119 120 121 122 123 124

protected:
  //! Reference to the enclosing DataTree
  DataTree &datatree;

  //! Index number
  int idx;

sebastien's avatar
sebastien committed
125 126 127
  //! Is the data member non_null_derivatives initialized ?
  bool preparedForDerivation;

128
  //! Set of derivation IDs with respect to which the derivative is potentially non-null
129 130 131 132 133 134 135 136 137 138 139 140 141
  set<int> non_null_derivatives;

  //! Used for caching of first order derivatives (when non-null)
  map<int, NodeID> derivatives;

  //! Cost of computing current node
  /*! Nodes included in temporary_terms are considered having a null cost */
  virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;

public:
  ExprNode(DataTree &datatree_arg);
  virtual ~ExprNode();

sebastien's avatar
sebastien committed
142 143 144
  //! Initializes data member non_null_derivatives
  virtual void prepareForDerivation() = 0;

145
  //! Returns derivative w.r. to derivation ID
146
  /*! Uses a symbolic a priori to pre-detect null derivatives, and caches the result for other derivatives (to avoid computing it several times)
147
    For an equal node, returns the derivative of lhs minus rhs */
148
  NodeID getDerivative(int deriv_id);
149

150 151 152 153 154 155
  //! Computes derivatives by applying the chain rule for some variables
  /*!
    \param deriv_id The derivation ID with respect to which we are derivating
    \param recursive_variables Contains the derivation ID for which chain rules must be applied. Keys are derivation IDs, values are equations of the form x=f(y) where x is the key variable and x doesn't appear in y
  */
  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables) = 0;
ferhat's avatar
ferhat committed
156

157 158 159 160 161 162 163 164 165
  //! Returns precedence of node
  /*! Equals 100 for constants, variables, unary ops, and temporary terms */
  virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;

  //! Fills temporary_terms set, using reference counts
  /*! A node will be marked as a temporary term if it is referenced at least two times (i.e. has at least two parents), and has a computing cost (multiplied by reference count) greater to datatree.min_cost */
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;

  //! Writes output of node, using a Txxx notation for nodes in temporary_terms
166
  virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const = 0;
167 168 169 170

  //! Writes output of node (with no temporary terms and with "outside model" output type)
  void writeOutput(ostream &output);

sebastien's avatar
sebastien committed
171 172 173 174 175 176 177 178 179
  //! Computes the set of all variables of a given symbol type in the expression
  /*!
    Variables are stored as integer pairs of the form (symb_id, lag).
    They are added to the set given in argument.
    Note that model local variables are substituted by their expression in the computation
    (and added if type_arg = ModelLocalVariable).
  */
  virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const = 0;

180
  //! Computes the set of endogenous variables in the expression
181
  /*!
182
    Endogenous are stored as integer pairs of the form (type_specific_id, lag).
183
    They are added to the set given in argument.
sebastien's avatar
sebastien committed
184
    Note that model local variables are substituted by their expression in the computation.
185
  */
sebastien's avatar
sebastien committed
186
  virtual void collectEndogenous(set<pair<int, int> > &result) const;
187 188 189

  //! Computes the set of exogenous variables in the expression
  /*!
190
    Exogenous are stored as integer pairs of the form (type_specific_id, lag).
191
    They are added to the set given in argument.
sebastien's avatar
sebastien committed
192
    Note that model local variables are substituted by their expression in the computation.
193
  */
sebastien's avatar
sebastien committed
194 195 196 197 198 199 200 201
  virtual void collectExogenous(set<pair<int, int> > &result) const;

  //! Computes the set of model local variables in the expression
  /*!
    Symbol IDs of these model local variables are added to the set given in argument.
    Note that this method is called recursively on the expressions associated to the model local variables detected.
  */
  virtual void collectModelLocalVariables(set<int> &result) const;
202

203
  virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const = 0;
204 205
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
                                     temporary_terms_type &temporary_terms,
206
                                     map<NodeID, pair<int, int> > &first_occurence,
207 208
                                     int Curr_block,
                                     Model_Block *ModelBlock,
209
                                     int equation,
210 211 212
                                     map_idx_type &map_idx) const;

  class EvalException
ferhat's avatar
ferhat committed
213

214 215 216 217
  {
  };

  virtual double eval(const eval_context_type &eval_context) const throw (EvalException) = 0;
218
  virtual void compile(ostream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx, bool dynamic, bool steady_dynamic) const = 0;
sebastien's avatar
sebastien committed
219 220 221 222 223 224
  //! Creates a static version of this node
  /*!
    This method duplicates the current node by creating a similar node from which all leads/lags have been stripped,
    adds the result in the static_datatree argument (and not in the original datatree), and returns it.
  */
  virtual NodeID toStatic(DataTree &static_datatree) const = 0;
ferhat's avatar
ferhat committed
225
  //! Try to normalize an equation linear in its endogenous variable
226
  virtual pair<int, NodeID> normalizeEquation(int symb_id_endo, vector<pair<int, pair<NodeID, NodeID> > > &List_of_Op_RHS) const = 0;
sebastien's avatar
sebastien committed
227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270

  //! Returns the maximum lead of endogenous in this expression
  /*! Always returns a non-negative value */
  virtual int maxEndoLead() const = 0;

  //! Returns a new expression where all the leads/lags have been shifted backwards by the same amount
  /*!
    Only acts on endogenous, exogenous, exogenous det
    \param[in] n The number of lags by which to shift
    \return The same expression except that leads/lags have been shifted backwards
  */
  virtual NodeID decreaseLeadsLags(int n) const = 0;

  //! Type for the substitution map used in the process of creating auxiliary vars for leads >= 2
  typedef map<const ExprNode *, const VariableNode *> subst_table_t;

  //! Creates auxiliary lead variables corresponding to this expression
  /*! 
    If maximum endogenous lead >= 3, this method will also create intermediary auxiliary var, and will add the equations of the form aux1 = aux2(+1) to the substitution table.
    \pre This expression is assumed to have maximum endogenous lead >= 2
    \param[in,out] subst_table The table to which new auxiliary variables and their correspondance will be added
    \return The new variable node corresponding to the current expression
  */
  VariableNode *createLeadAuxiliaryVarForMyself(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;

  //! Constructs a new expression where sub-expressions with max endo lead >= 2 have been replaced by auxiliary variables
  /*!
    \param[in,out] subst_table Map used to store expressions that have already be substituted and their corresponding variable, in order to avoid creating two auxiliary variables for the same sub-expr.
    \param[out] neweqs Equations to be added to the model to match the creation of auxiliary variables.

    If the method detects a sub-expr which needs to be substituted, two cases are possible:
    - if this expr is in the table, then it will use the corresponding variable and return the substituted expression
    - if this expr is not in the table, then it will create an auxiliary endogenous variable, add the substitution in the table and return the substituted expression

    \return A new equivalent expression where sub-expressions with max endo lead >= 2 have been replaced by auxiliary variables
    */
  virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;

  //! Constructs a new expression where endo variables with max endo lag >= 2 have been replaced by auxiliary variables
  /*!
    \param[in,out] subst_table Map used to store expressions that have already be substituted and their corresponding variable, in order to avoid creating two auxiliary variables for the same sub-expr.
    \param[out] neweqs Equations to be added to the model to match the creation of auxiliary variables.
  */
  virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const = 0;
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288
};

//! Comparator ordering expression nodes by ascending index number (ExprNode::idx)
/*! Used as the ordering of temporary_terms_type. It relies on friendship with ExprNode
    to read the protected idx member. */
struct ExprNodeLess
{
  bool operator()(NodeID arg1, NodeID arg2) const
  {
    const int left_index = arg1->idx;
    const int right_index = arg2->idx;
    return left_index < right_index;
  }
};

//! Numerical constant node
/*! The constant is necessarily non-negative (this is enforced at the NumericalConstants class level) */
class NumConstNode : public ExprNode
{
private:
  //! Id from numerical constants table
  const int id;
289
  virtual NodeID computeDerivative(int deriv_id);
290 291
public:
  NumConstNode(DataTree &datatree_arg, int id_arg);
sebastien's avatar
sebastien committed
292
  virtual void prepareForDerivation();
293
  virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
sebastien's avatar
sebastien committed
294
  virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
295
  virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const;
296
  virtual double eval(const eval_context_type &eval_context) const throw (EvalException);
297
  virtual void compile(ostream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx, bool dynamic, bool steady_dynamic) const;
sebastien's avatar
sebastien committed
298
  virtual NodeID toStatic(DataTree &static_datatree) const;
299
  virtual pair<int, NodeID> normalizeEquation(int symb_id_endo, vector<pair<int, pair<NodeID, NodeID> > >  &List_of_Op_RHS) const;
300
  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
sebastien's avatar
sebastien committed
301 302 303 304
  virtual int maxEndoLead() const;
  virtual NodeID decreaseLeadsLags(int n) const;
  virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
  virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
305 306 307 308 309 310 311 312
};

//! Symbol or variable node
class VariableNode : public ExprNode
{
private:
  //! Id from the symbol table
  const int symb_id;
313
  const SymbolType type;
314
  const int lag;
sebastien's avatar
sebastien committed
315
  virtual NodeID computeDerivative(int deriv_id);
316
public:
sebastien's avatar
sebastien committed
317 318
  VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg);
  virtual void prepareForDerivation();
319
  virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
sebastien's avatar
sebastien committed
320
  virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
321
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
322 323 324 325 326 327
                                     temporary_terms_type &temporary_terms,
                                     map<NodeID, pair<int, int> > &first_occurence,
                                     int Curr_block,
                                     Model_Block *ModelBlock,
                                     int equation,
                                     map_idx_type &map_idx) const;
328
  virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const;
329
  virtual double eval(const eval_context_type &eval_context) const throw (EvalException);
330
  virtual void compile(ostream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx, bool dynamic, bool steady_dynamic) const;
sebastien's avatar
sebastien committed
331
  virtual NodeID toStatic(DataTree &static_datatree) const;
332
  int get_symb_id() const { return symb_id; };
333
  virtual pair<int, NodeID> normalizeEquation(int symb_id_endo, vector<pair<int, pair<NodeID, NodeID> > >  &List_of_Op_RHS) const;
334
  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
sebastien's avatar
sebastien committed
335 336 337 338
  virtual int maxEndoLead() const;
  virtual NodeID decreaseLeadsLags(int n) const;
  virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
  virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
339 340 341 342 343 344 345 346
};

//! Unary operator node
class UnaryOpNode : public ExprNode
{
private:
  const NodeID arg;
  const UnaryOpcode op_code;
347
  virtual NodeID computeDerivative(int deriv_id);
348
  virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
349 350
  //! Returns the derivative of this node if darg is the derivative of the argument
  NodeID composeDerivatives(NodeID darg);
351 352
public:
  UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg);
sebastien's avatar
sebastien committed
353
  virtual void prepareForDerivation();
354 355 356 357
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
  virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
                                     temporary_terms_type &temporary_terms,
358
                                     map<NodeID, pair<int, int> > &first_occurence,
359 360
                                     int Curr_block,
                                     Model_Block *ModelBlock,
361
                                     int equation,
362
                                     map_idx_type &map_idx) const;
sebastien's avatar
sebastien committed
363
  virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
364
  virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const;
365 366
  static double eval_opcode(UnaryOpcode op_code, double v) throw (EvalException);
  virtual double eval(const eval_context_type &eval_context) const throw (EvalException);
367
  virtual void compile(ostream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx, bool dynamic, bool steady_dynamic) const;
sebastien's avatar
sebastien committed
368 369 370 371 372
  //! Returns operand
  NodeID get_arg() const { return(arg); };
  //! Returns op code
  UnaryOpcode get_op_code() const { return(op_code); };
  virtual NodeID toStatic(DataTree &static_datatree) const;
373
  virtual pair<int, NodeID> normalizeEquation(int symb_id_endo, vector<pair<int, pair<NodeID, NodeID> > >  &List_of_Op_RHS) const;
374
  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
sebastien's avatar
sebastien committed
375 376 377 378 379 380
  virtual int maxEndoLead() const;
  virtual NodeID decreaseLeadsLags(int n) const;
  virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
  //! Creates another UnaryOpNode with the same opcode, but with a possibly different datatree and argument
  NodeID buildSimilarUnaryOpNode(NodeID alt_arg, DataTree &alt_datatree) const;
  virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
381 382 383 384 385 386 387 388
};

//! Binary operator node
class BinaryOpNode : public ExprNode
{
private:
  const NodeID arg1, arg2;
  const BinaryOpcode op_code;
389
  virtual NodeID computeDerivative(int deriv_id);
390
  virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
391 392
  //! Returns the derivative of this node if darg1 and darg2 are the derivatives of the arguments
  NodeID composeDerivatives(NodeID darg1, NodeID darg2);
393 394 395
public:
  BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
               BinaryOpcode op_code_arg, const NodeID arg2_arg);
sebastien's avatar
sebastien committed
396
  virtual void prepareForDerivation();
397 398 399 400 401
  virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
  virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
                                     temporary_terms_type &temporary_terms,
402
                                     map<NodeID, pair<int, int> > &first_occurence,
403 404
                                     int Curr_block,
                                     Model_Block *ModelBlock,
405
                                     int equation,
406
                                     map_idx_type &map_idx) const;
sebastien's avatar
sebastien committed
407
  virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
408
  virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const;
409 410
  static double eval_opcode(double v1, BinaryOpcode op_code, double v2) throw (EvalException);
  virtual double eval(const eval_context_type &eval_context) const throw (EvalException);
411
  virtual void compile(ostream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx, bool dynamic, bool steady_dynamic) const;
412
  virtual NodeID Compute_RHS(NodeID arg1, NodeID arg2, int op, int op_type) const;
sebastien's avatar
sebastien committed
413 414 415 416 417 418 419
  //! Returns first operand
  NodeID get_arg1() const { return(arg1); };
  //! Returns second operand
  NodeID get_arg2() const { return(arg2); };
  //! Returns op code
  BinaryOpcode get_op_code() const { return(op_code); };
  virtual NodeID toStatic(DataTree &static_datatree) const;
420
  virtual pair<int, NodeID> normalizeEquation(int symb_id_endo, vector<pair<int, pair<NodeID, NodeID> > >  &List_of_Op_RHS) const;
421
  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
sebastien's avatar
sebastien committed
422 423 424 425 426 427
  virtual int maxEndoLead() const;
  virtual NodeID decreaseLeadsLags(int n) const;
  virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
  //! Creates another BinaryOpNode with the same opcode, but with a possibly different datatree and arguments
  NodeID buildSimilarBinaryOpNode(NodeID alt_arg1, NodeID alt_arg2, DataTree &alt_datatree) const;
  virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
428 429 430 431 432 433 434 435 436
};

//! Trinary operator node
class TrinaryOpNode : public ExprNode
{
  friend class ModelTree;
private:
  const NodeID arg1, arg2, arg3;
  const TrinaryOpcode op_code;
437
  virtual NodeID computeDerivative(int deriv_id);
438
  virtual int cost(const temporary_terms_type &temporary_terms, bool is_matlab) const;
439 440
  //! Returns the derivative of this node if darg1, darg2 and darg3 are the derivatives of the arguments
  NodeID composeDerivatives(NodeID darg1, NodeID darg2, NodeID darg3);
441 442 443
public:
  TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
		TrinaryOpcode op_code_arg, const NodeID arg2_arg, const NodeID arg3_arg);
sebastien's avatar
sebastien committed
444
  virtual void prepareForDerivation();
445 446 447 448 449
  virtual int precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
  virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
                                     temporary_terms_type &temporary_terms,
450
                                     map<NodeID, pair<int, int> > &first_occurence,
451 452
                                     int Curr_block,
                                     Model_Block *ModelBlock,
453
                                     int equation,
454
                                     map_idx_type &map_idx) const;
sebastien's avatar
sebastien committed
455
  virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
456
  virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const;
457 458
  static double eval_opcode(double v1, TrinaryOpcode op_code, double v2, double v3) throw (EvalException);
  virtual double eval(const eval_context_type &eval_context) const throw (EvalException);
459
  virtual void compile(ostream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx, bool dynamic, bool steady_dynamic) const;
sebastien's avatar
sebastien committed
460
  virtual NodeID toStatic(DataTree &static_datatree) const;
461
  virtual pair<int, NodeID> normalizeEquation(int symb_id_endo, vector<pair<int, pair<NodeID, NodeID> > >  &List_of_Op_RHS) const;
462
  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
sebastien's avatar
sebastien committed
463 464 465 466 467 468
  virtual int maxEndoLead() const;
  virtual NodeID decreaseLeadsLags(int n) const;
  virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
  //! Creates another TrinaryOpNode with the same opcode, but with a possibly different datatree and arguments
  NodeID buildSimilarTrinaryOpNode(NodeID alt_arg1, NodeID alt_arg2, NodeID alt_arg3, DataTree &alt_datatree) const;
  virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
469 470 471 472 473 474 475 476
};

//! Unknown function node
class UnknownFunctionNode : public ExprNode
{
private:
  const int symb_id;
  const vector<NodeID> arguments;
477
  virtual NodeID computeDerivative(int deriv_id);
478 479 480
public:
  UnknownFunctionNode(DataTree &datatree_arg, int symb_id_arg,
                      const vector<NodeID> &arguments_arg);
sebastien's avatar
sebastien committed
481
  virtual void prepareForDerivation();
482 483 484 485
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count, temporary_terms_type &temporary_terms, bool is_matlab) const;
  virtual void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const;
  virtual void computeTemporaryTerms(map<NodeID, int> &reference_count,
                                     temporary_terms_type &temporary_terms,
486
                                     map<NodeID, pair<int, int> > &first_occurence,
487 488
                                     int Curr_block,
                                     Model_Block *ModelBlock,
489
                                     int equation,
490
                                     map_idx_type &map_idx) const;
sebastien's avatar
sebastien committed
491
  virtual void collectVariables(SymbolType type_arg, set<pair<int, int> > &result) const;
492
  virtual void collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const;
493
  virtual double eval(const eval_context_type &eval_context) const throw (EvalException);
494
  virtual void compile(ostream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx, bool dynamic, bool steady_dynamic) const;
sebastien's avatar
sebastien committed
495
  virtual NodeID toStatic(DataTree &static_datatree) const;
496
  virtual pair<int, NodeID> normalizeEquation(int symb_id_endo, vector<pair<int, pair<NodeID, NodeID> > >  &List_of_Op_RHS) const;
497
  virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables);
sebastien's avatar
sebastien committed
498 499 500 501
  virtual int maxEndoLead() const;
  virtual NodeID decreaseLeadsLags(int n) const;
  virtual NodeID substituteLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
  virtual NodeID substituteLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;
502 503 504
};

#endif