diff --git a/src/DataTree.cc b/src/DataTree.cc index f24c3030a6d53126e25ae35be1728d7f4c03d465..b315eb7514689fbd556e34961382c98ba08712ac 100644 --- a/src/DataTree.cc +++ b/src/DataTree.cc @@ -873,7 +873,7 @@ DataTree::minLagForSymbol(int symb_id) const } void -DataTree::writePowerDeriv(ostream &output) const +DataTree::writeCHelpersDefinition(ostream &output) const { if (isBinaryOpUsed(BinaryOpcode::powerDeriv)) output << "/*" << endl @@ -892,6 +892,12 @@ DataTree::writePowerDeriv(ostream &output) const << " return dxp;" << endl << " }" << endl << "}" << endl; + + if (isUnaryOpUsed(UnaryOpcode::sign)) + output << "double sign(double x)" << endl + << "{" << endl + << " return (x > 0) ? 1 : ((x < 0) ? -1 : 0);" << endl + << "}" << endl; } void @@ -915,10 +921,12 @@ DataTree::writePowerDerivJulia(ostream &output) const } void -DataTree::writePowerDerivHeader(ostream &output) const +DataTree::writeCHelpersDeclaration(ostream &output) const { if (isBinaryOpUsed(BinaryOpcode::powerDeriv)) output << "double getPowerDeriv(double x, double p, int k);" << endl; + if (isUnaryOpUsed(UnaryOpcode::sign)) + output << "double sign(double x);" << endl; } string diff --git a/src/DataTree.hh b/src/DataTree.hh index 6e9f818dbba12ecb640869c8de4138279f6eeb8a..61dfcfbf13e17c3714b8b9b718fc5c2f35b6ee44 100644 --- a/src/DataTree.hh +++ b/src/DataTree.hh @@ -280,12 +280,12 @@ public: //! Returns the minimum lag (as a negative number) of the given symbol in the whole data tree (and not only in the equations !!) /*! Returns 0 if the symbol is not used */ int minLagForSymbol(int symb_id) const; - //! Write getPowerDeriv in C (function body) - void writePowerDeriv(ostream &output) const; - //! Write getPowerDeriv in C (prototype) - void writePowerDerivHeader(ostream &output) const; //! Write getPowerDeriv in Julia void writePowerDerivJulia(ostream &output) const; + //! Writes definitions of C function helpers (getPowerDeriv(), sign()) + void writeCHelpersDefinition(ostream &output) const; + //! Writes declarations of C function helpers (getPowerDeriv(), sign()) + void writeCHelpersDeclaration(ostream &output) const; //! Thrown when trying to access an unknown variable by deriv_id class UnknownDerivIDException { diff --git a/src/DynamicModel.cc b/src/DynamicModel.cc index ea98032b0542efed4f97d637d6140cbfffe2ac09..f196862602f4c760e90788bc6337aa36c7bb022b 100644 --- a/src/DynamicModel.cc +++ b/src/DynamicModel.cc @@ -631,7 +631,7 @@ DynamicModel::writeDynamicPerBlockCFiles(const string &basename) const << endl; // Write function definition if BinaryOpcode::powerDeriv is used - writePowerDerivHeader(output); + writeCHelpersDeclaration(output); output << endl; @@ -1447,7 +1447,7 @@ DynamicModel::writeDynamicCFile(const string &basename) const << endl; // Write function definition if BinaryOpcode::powerDeriv is used - writePowerDeriv(output); + writeCHelpersDefinition(output); output << endl; @@ -1671,7 +1671,7 @@ DynamicModel::writeDynamicBlockCFile(const string &basename) const output << R"(#include "dynamic_)" << blk+1 << R"(.h")" << endl; output << endl; - writePowerDeriv(output); + writeCHelpersDefinition(output); output << endl << "void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])" << endl diff --git a/src/ExprNode.cc b/src/ExprNode.cc index d1ab86901c8286099dab7c8f0e993abfefedfedd..7c512ab8ade0ca82fe94e87384e829da0b0633dd 100644 --- a/src/ExprNode.cc +++ b/src/ExprNode.cc @@ -2693,10 +2693,10 @@ UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, output << "abs"; break; case UnaryOpcode::sign: - if (output_type == ExprNodeOutputType::CDynamicModel || output_type == ExprNodeOutputType::CStaticModel) - output << "copysign"; - else - output << "sign"; + /* C does not have a sign() function, and copysign() is not suitable + because it does not handle zero correctly, so we define our own sign() + helper function, see DataTree::writeCHelpersDefinition() */ + output << "sign"; break; case UnaryOpcode::steadyState: ExprNodeOutputType new_output_type; @@ -2787,8 +2787,6 @@ UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type, && arg->precedence(output_type, temporary_terms) < precedence(output_type, temporary_terms))) { output << LEFT_PAR(output_type); - if (op_code == UnaryOpcode::sign && (output_type == ExprNodeOutputType::CDynamicModel || output_type == ExprNodeOutputType::CStaticModel)) - output << "1.0,"; close_parenthesis = true; } diff --git a/src/StaticModel.cc b/src/StaticModel.cc index ba9d28f640b2ce00d11a6cc20e9dd69e3b3c0960..37cae8b17aed5c9c11202a535fd4763fac2e35d7 100644 --- a/src/StaticModel.cc +++ b/src/StaticModel.cc @@ -311,7 +311,7 @@ StaticModel::writeStaticPerBlockCFiles(const string &basename) const << endl; // Write function definition if BinaryOpcode::powerDeriv is used - writePowerDerivHeader(output); + writeCHelpersDeclaration(output); output << endl; @@ -1730,7 +1730,7 @@ StaticModel::writeStaticCFile(const string &basename) const << endl; // Write function definition if BinaryOpcode::powerDeriv is used - writePowerDeriv(output); + writeCHelpersDefinition(output); output << endl; @@ -1914,7 +1914,7 @@ StaticModel::writeStaticBlockCFile(const string &basename) const output << R"(#include "static_)" << blk+1 << R"(.h")" << endl; output << endl; - writePowerDeriv(output); + writeCHelpersDefinition(output); output << endl << "void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])" << endl