diff --git a/src/ExprNode.hh b/src/ExprNode.hh index 8e2b9d71492aba9b71dc05e88bde12cd15d2592f..5ecab5f445a20ff1cca73bf8828f017bc4bf367c 100644 --- a/src/ExprNode.hh +++ b/src/ExprNode.hh @@ -275,6 +275,9 @@ protected: return is_matlab ? min_cost_matlab : min_cost_c; }; + //! Initializes data member non_null_derivatives + virtual void prepareForDerivation() = 0; + //! Cost of computing current node /*! Nodes included in temporary_terms are considered having a null cost */ virtual int cost(int cost, bool is_matlab) const; @@ -323,9 +326,6 @@ public: ExprNode(const ExprNode &) = delete; ExprNode &operator=(const ExprNode &) = delete; - //! Initializes data member non_null_derivatives - virtual void prepareForDerivation() = 0; - //! Returns derivative w.r. to derivation ID /*! Uses a symbolic a priori to pre-detect null derivatives, and caches the result for other derivatives (to avoid computing it several times) For an equal node, returns the derivative of lhs minus rhs */ @@ -845,11 +845,11 @@ private: expr_t computeDerivative(int deriv_id) override; expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; protected: + void prepareForDerivation() override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: NumConstNode(DataTree &datatree_arg, int idx_arg, int id_arg); - void prepareForDerivation() override; void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; void writeJsonAST(ostream &output) const override; void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; @@ -917,11 +917,11 @@ private: expr_t computeDerivative(int deriv_id) override; expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override; protected: + void prepareForDerivation() override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: VariableNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, int lag_arg); - void prepareForDerivation() override; void writeOutput(ostream &output, ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms, const temporary_terms_idxs_t &temporary_terms_idxs, const deriv_node_temp_terms_t &tef_terms) const override; void writeJsonAST(ostream &output) const override; void writeJsonOutput(ostream &output, const temporary_terms_t &temporary_terms, const deriv_node_temp_terms_t &tef_terms, bool isdynamic) const override; @@ -984,6 +984,7 @@ public: class UnaryOpNode : public ExprNode { protected: + void prepareForDerivation() override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: @@ -1005,7 +1006,6 @@ private: expr_t composeDerivatives(expr_t darg, int deriv_id); public: UnaryOpNode(DataTree &datatree_arg, int idx_arg, UnaryOpcode op_code_arg, const expr_t arg_arg, int expectation_information_set_arg, int param1_symb_id_arg, int param2_symb_id_arg, string adl_param_name_arg, vector<int> adl_lags_arg); - void prepareForDerivation() override; void computeTemporaryTerms(const pair<int, int> &derivOrder, map<pair<int, int>, temporary_terms_t> &temp_terms_map, map<expr_t, pair<int, pair<int, int>>> &reference_count, @@ -1089,6 +1089,7 @@ public: class BinaryOpNode : public ExprNode { protected: + void prepareForDerivation() override; void matchVTCTPHelper(optional<int> &var_id, int &lag, optional<int> ¶m_id, double &constant, bool at_denominator) const override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; public: @@ -1107,7 +1108,6 @@ private: public: BinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg, BinaryOpcode op_code_arg, const expr_t arg2_arg, int powerDerivOrder); - void prepareForDerivation() override; int precedenceJson(const temporary_terms_t &temporary_terms) const override; int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override; void computeTemporaryTerms(const pair<int, int> &derivOrder, @@ -1239,6 +1239,7 @@ public: const expr_t arg1, arg2, arg3; const TrinaryOpcode op_code; protected: + void prepareForDerivation() override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; private: expr_t computeDerivative(int deriv_id) override; @@ -1251,7 +1252,6 @@ private: public: TrinaryOpNode(DataTree &datatree_arg, int idx_arg, const expr_t arg1_arg, TrinaryOpcode op_code_arg, const expr_t arg2_arg, const expr_t arg3_arg); - void prepareForDerivation() override; int precedence(ExprNodeOutputType output_type, const temporary_terms_t &temporary_terms) const override; void computeTemporaryTerms(const pair<int, int> &derivOrder, map<pair<int, int>, temporary_terms_t> &temp_terms_map, @@ -1347,6 +1347,7 @@ protected: class UnknownFunctionNameAndArgs { }; + void prepareForDerivation() override; //! Returns true if the given external function has been written as a temporary term bool alreadyWrittenAsTefTerm(int the_symb_id, const deriv_node_temp_terms_t &tef_terms) const; //! Returns the index in the tef_terms map of this external function @@ -1368,7 +1369,6 @@ protected: public: AbstractExternalFunctionNode(DataTree &datatree_arg, int idx_arg, int symb_id_arg, vector<expr_t> arguments_arg); - void prepareForDerivation() override; void computeTemporaryTerms(const pair<int, int> &derivOrder, map<pair<int, int>, temporary_terms_t> &temp_terms_map, map<expr_t, pair<int, pair<int, int>>> &reference_count, @@ -1575,7 +1575,6 @@ public: void computeBlockTemporaryTerms(int blk, int eq, vector<vector<temporary_terms_t>> &blocks_temporary_terms, map<expr_t, tuple<int, int, int>> &reference_count) const override; expr_t toStatic(DataTree &static_datatree) const override; - void prepareForDerivation() override; expr_t computeDerivative(int deriv_id) override; int maxEndoLead() const override; int maxExoLead() const override; @@ -1622,6 +1621,7 @@ public: expr_t removeTrendLeadLag(const map<int, expr_t> &trend_symbols_map) const override; expr_t substituteLogTransform(int orig_symb_id, int aux_symb_id) const override; protected: + void prepareForDerivation() override; void computeSubExprContainingVariable(int symb_id, int lag, set<expr_t> &contain_var) const override; private: expr_t computeChainRuleDerivative(int deriv_id, const map<int, BinaryOpNode *> &recursive_variables, map<pair<expr_t, int>, expr_t> &cache) override;