diff --git a/preprocessor/ExprNode.cc b/preprocessor/ExprNode.cc index ee47e18514685e1e3b5d6a3d5d9845a4be9712b0..710641de7dc0d54991de986713045d12c7581beb 100644 --- a/preprocessor/ExprNode.cc +++ b/preprocessor/ExprNode.cc @@ -319,6 +319,18 @@ NumConstNode::maxExoLead() const return 0; } +int +NumConstNode::maxEndoLag() const +{ + return 0; +} + +int +NumConstNode::maxExoLag() const +{ + return 0; +} + NodeID NumConstNode::decreaseLeadsLags(int n) const { @@ -836,6 +848,34 @@ VariableNode::maxExoLead() const } } +int +VariableNode::maxEndoLag() const +{ + switch (type) + { + case eEndogenous: + return max(-lag, 0); + case eModelLocalVariable: + return datatree.local_variables_table[symb_id]->maxEndoLag(); + default: + return 0; + } +} + +int +VariableNode::maxExoLag() const +{ + switch (type) + { + case eExogenous: + return max(-lag, 0); + case eModelLocalVariable: + return datatree.local_variables_table[symb_id]->maxExoLag(); + default: + return 0; + } +} + NodeID VariableNode::decreaseLeadsLags(int n) const { @@ -887,6 +927,7 @@ NodeID VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const { VariableNode *substexpr; + NodeID value; subst_table_t::const_iterator it; int cur_lag; switch (type) @@ -923,7 +964,11 @@ VariableNode::substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector return substexpr; case eModelLocalVariable: - return datatree.local_variables_table[symb_id]->substituteEndoLagGreaterThanTwo(subst_table, neweqs); + value = datatree.local_variables_table[symb_id]; + if (value->maxEndoLag() <= 1) + return const_cast<VariableNode *>(this); + else + return value->substituteEndoLagGreaterThanTwo(subst_table, neweqs); default: return const_cast<VariableNode *>(this); } @@ -955,6 +1000,7 @@ NodeID VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const { VariableNode *substexpr; + NodeID value; subst_table_t::const_iterator it; int cur_lag; switch (type) @@ -991,7 +1037,11 @@ VariableNode::substituteExoLag(subst_table_t &subst_table, vector<BinaryOpNode * return substexpr; case eModelLocalVariable: - return datatree.local_variables_table[symb_id]->substituteExoLag(subst_table, neweqs); + value = datatree.local_variables_table[symb_id]; + if (value->maxExoLag() == 0) + return const_cast<VariableNode *>(this); + else + return value->substituteExoLag(subst_table, neweqs); default: return const_cast<VariableNode *>(this); } @@ -1669,6 +1719,18 @@ UnaryOpNode::maxExoLead() const return arg->maxExoLead(); } +int +UnaryOpNode::maxEndoLag() const +{ + return arg->maxEndoLag(); +} + +int +UnaryOpNode::maxExoLag() const +{ + return arg->maxExoLag(); +} + NodeID UnaryOpNode::decreaseLeadsLags(int n) const { @@ -2659,6 +2721,18 @@ BinaryOpNode::maxExoLead() const return max(arg1->maxExoLead(), arg2->maxExoLead()); } +int +BinaryOpNode::maxEndoLag() const +{ + return max(arg1->maxEndoLag(), arg2->maxEndoLag()); +} + +int +BinaryOpNode::maxExoLag() const +{ + return max(arg1->maxExoLag(), arg2->maxExoLag()); +} + NodeID BinaryOpNode::decreaseLeadsLags(int n) const { @@ -3149,6 +3223,18 @@ TrinaryOpNode::maxExoLead() const return max(arg1->maxExoLead(), max(arg2->maxExoLead(), arg3->maxExoLead())); } +int +TrinaryOpNode::maxEndoLag() const +{ + return max(arg1->maxEndoLag(), max(arg2->maxEndoLag(), arg3->maxEndoLag())); +} + +int +TrinaryOpNode::maxExoLag() const +{ + return max(arg1->maxExoLag(), max(arg2->maxExoLag(), arg3->maxExoLag())); +} + NodeID TrinaryOpNode::decreaseLeadsLags(int n) const { @@ -3373,6 +3459,26 @@ UnknownFunctionNode::maxExoLead() const return val; } +int +ExternalFunctionNode::maxEndoLag() const +{ + int val = 0; + for (vector<NodeID>::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + val = max(val, (*it)->maxEndoLag()); + return val; +} + +int +ExternalFunctionNode::maxExoLag() const +{ + int val = 0; + for (vector<NodeID>::const_iterator it = arguments.begin(); + it != arguments.end(); it++) + val = max(val, (*it)->maxExoLag()); + return val; +} + NodeID UnknownFunctionNode::decreaseLeadsLags(int n) const { diff --git a/preprocessor/ExprNode.hh b/preprocessor/ExprNode.hh index a9179e04d220e4fda7c695dea8d1ec7371e7cb50..573a9e7958ab846739f86fc4fcaa91dac03ede23 100644 --- a/preprocessor/ExprNode.hh +++ b/preprocessor/ExprNode.hh @@ -237,6 +237,14 @@ public: /*! Always returns a non-negative value */ virtual int maxExoLead() const = 0; + //! Returns the maximum lag of endogenous in this expression + /*! Always returns a non-negative value */ + virtual int maxEndoLag() const = 0; + + //! Returns the maximum lag of exogenous in this expression + /*! Always returns a non-negative value */ + virtual int maxExoLag() const = 0; + //! Returns a new expression where all the leads/lags have been shifted backwards by the same amount /*! Only acts on endogenous, exogenous, exogenous det @@ -363,6 +371,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; @@ -381,6 +391,7 @@ private: //! Id from the symbol table const int symb_id; const SymbolType type; + //! A positive value is a lead, a negative is a lag const int lag; virtual NodeID computeDerivative(int deriv_id); public: @@ -407,6 +418,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; @@ -463,6 +476,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; //! Creates another UnaryOpNode with the same opcode, but with a possibly different datatree and argument @@ -528,6 +543,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; //! Creates another BinaryOpNode with the same opcode, but with a possibly different datatree and arguments @@ -575,6 +592,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; //! Creates another TrinaryOpNode with the same opcode, but with a possibly different datatree and arguments @@ -616,6 +635,8 @@ public: virtual NodeID getChainRuleDerivative(int deriv_id, const map<int, NodeID> &recursive_variables); virtual int maxEndoLead() const; virtual int maxExoLead() const; + virtual int maxEndoLag() const; + virtual int maxExoLag() const; virtual NodeID decreaseLeadsLags(int n) const; virtual NodeID substituteEndoLeadGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const; virtual NodeID substituteEndoLagGreaterThanTwo(subst_table_t &subst_table, vector<BinaryOpNode *> &neweqs) const;