Skip to content
Snippets Groups Projects
Commit 4f6a3669 authored by Houtan Bastani's avatar Houtan Bastani
Browse files

Added expression sharing for external functions

parent 28894927
No related branches found
No related tags found
No related merge requests found
...@@ -466,6 +466,11 @@ NodeID ...@@ -466,6 +466,11 @@ NodeID
DataTree::AddExternalFunction(int symb_id, const vector<NodeID> &arguments) DataTree::AddExternalFunction(int symb_id, const vector<NodeID> &arguments)
{ {
assert(symbol_table.getType(symb_id) == eExternalFunction); assert(symbol_table.getType(symb_id) == eExternalFunction);
external_function_node_map_type::iterator it = external_function_node_map.find(make_pair(arguments, symb_id));
if (it != external_function_node_map.end())
return it->second;
return new ExternalFunctionNode(*this, symb_id, arguments); return new ExternalFunctionNode(*this, symb_id, arguments);
} }
...@@ -473,6 +478,13 @@ NodeID ...@@ -473,6 +478,13 @@ NodeID
DataTree::AddFirstDerivExternalFunctionNode(int top_level_symb_id, const vector<NodeID> &arguments, int input_index) DataTree::AddFirstDerivExternalFunctionNode(int top_level_symb_id, const vector<NodeID> &arguments, int input_index)
{ {
assert(symbol_table.getType(top_level_symb_id) == eExternalFunction); assert(symbol_table.getType(top_level_symb_id) == eExternalFunction);
first_deriv_external_function_node_map_type::iterator it =
first_deriv_external_function_node_map.find(make_pair(make_pair(arguments, input_index),
top_level_symb_id));
if (it != first_deriv_external_function_node_map.end())
return it->second;
return new FirstDerivExternalFunctionNode(*this, top_level_symb_id, arguments, input_index); return new FirstDerivExternalFunctionNode(*this, top_level_symb_id, arguments, input_index);
} }
...@@ -480,6 +492,14 @@ NodeID ...@@ -480,6 +492,14 @@ NodeID
DataTree::AddSecondDerivExternalFunctionNode(int top_level_symb_id, const vector<NodeID> &arguments, int input_index1, int input_index2) DataTree::AddSecondDerivExternalFunctionNode(int top_level_symb_id, const vector<NodeID> &arguments, int input_index1, int input_index2)
{ {
assert(symbol_table.getType(top_level_symb_id) == eExternalFunction); assert(symbol_table.getType(top_level_symb_id) == eExternalFunction);
second_deriv_external_function_node_map_type::iterator it =
second_deriv_external_function_node_map.find(make_pair(make_pair(arguments,
make_pair(input_index1, input_index2)),
top_level_symb_id));
if (it != second_deriv_external_function_node_map.end())
return it->second;
return new SecondDerivExternalFunctionNode(*this, top_level_symb_id, arguments, input_index1, input_index2); return new SecondDerivExternalFunctionNode(*this, top_level_symb_id, arguments, input_index1, input_index2);
} }
...@@ -541,3 +561,36 @@ DataTree::isTrinaryOpUsed(TrinaryOpcode opcode) const ...@@ -541,3 +561,36 @@ DataTree::isTrinaryOpUsed(TrinaryOpcode opcode) const
return false; return false;
} }
bool
DataTree::isExternalFunctionUsed(int symb_id) const
{
for (external_function_node_map_type::const_iterator it = external_function_node_map.begin();
it != external_function_node_map.end(); it++)
if (it->first.second == symb_id)
return true;
return false;
}
bool
DataTree::isFirstDerivExternalFunctionUsed(int symb_id) const
{
for (first_deriv_external_function_node_map_type::const_iterator it = first_deriv_external_function_node_map.begin();
it != first_deriv_external_function_node_map.end(); it++)
if (it->first.second == symb_id)
return true;
return false;
}
bool
DataTree::isSecondDerivExternalFunctionUsed(int symb_id) const
{
for (second_deriv_external_function_node_map_type::const_iterator it = second_deriv_external_function_node_map.begin();
it != second_deriv_external_function_node_map.end(); it++)
if (it->first.second == symb_id)
return true;
return false;
}
...@@ -65,6 +65,12 @@ protected: ...@@ -65,6 +65,12 @@ protected:
binary_op_node_map_type binary_op_node_map; binary_op_node_map_type binary_op_node_map;
typedef map<pair<pair<pair<NodeID, NodeID>, NodeID>, TrinaryOpcode>, TrinaryOpNode *> trinary_op_node_map_type; typedef map<pair<pair<pair<NodeID, NodeID>, NodeID>, TrinaryOpcode>, TrinaryOpNode *> trinary_op_node_map_type;
trinary_op_node_map_type trinary_op_node_map; trinary_op_node_map_type trinary_op_node_map;
typedef map<pair<vector<NodeID>, int>, ExternalFunctionNode *> external_function_node_map_type;
external_function_node_map_type external_function_node_map;
typedef map<pair<pair<vector<NodeID>, int>, int>, FirstDerivExternalFunctionNode *> first_deriv_external_function_node_map_type;
first_deriv_external_function_node_map_type first_deriv_external_function_node_map;
typedef map<pair<pair<vector<NodeID>, pair<int, int> >, int>, SecondDerivExternalFunctionNode *> second_deriv_external_function_node_map_type;
second_deriv_external_function_node_map_type second_deriv_external_function_node_map;
//! Stores local variables value (maps symbol ID to corresponding node) //! Stores local variables value (maps symbol ID to corresponding node)
map<int, NodeID> local_variables_table; map<int, NodeID> local_variables_table;
...@@ -197,6 +203,12 @@ public: ...@@ -197,6 +203,12 @@ public:
bool isBinaryOpUsed(BinaryOpcode opcode) const; bool isBinaryOpUsed(BinaryOpcode opcode) const;
//! Checks if a given trinary op is used somewhere in the data tree //! Checks if a given trinary op is used somewhere in the data tree
bool isTrinaryOpUsed(TrinaryOpcode opcode) const; bool isTrinaryOpUsed(TrinaryOpcode opcode) const;
//! Checks if a given external function is used somewhere in the data tree
bool isExternalFunctionUsed(int symb_id) const;
//! Checks if a given first derivative external function is used somewhere in the data tree
bool isFirstDerivExternalFunctionUsed(int symb_id) const;
//! Checks if a given second derivative external function is used somewhere in the data tree
bool isSecondDerivExternalFunctionUsed(int symb_id) const;
//! Thrown when trying to access an unknown variable by deriv_id //! Thrown when trying to access an unknown variable by deriv_id
class UnknownDerivIDException class UnknownDerivIDException
{ {
... ...
......
...@@ -3181,6 +3181,8 @@ ExternalFunctionNode::ExternalFunctionNode(DataTree &datatree_arg, ...@@ -3181,6 +3181,8 @@ ExternalFunctionNode::ExternalFunctionNode(DataTree &datatree_arg,
symb_id(symb_id_arg), symb_id(symb_id_arg),
arguments(arguments_arg) arguments(arguments_arg)
{ {
// Add myself to the external function map
datatree.external_function_node_map[make_pair(arguments,symb_id)] = this;
} }
void void
...@@ -3428,6 +3430,8 @@ FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatre ...@@ -3428,6 +3430,8 @@ FirstDerivExternalFunctionNode::FirstDerivExternalFunctionNode(DataTree &datatre
ExternalFunctionNode(datatree_arg, top_level_symb_id_arg, arguments_arg), ExternalFunctionNode(datatree_arg, top_level_symb_id_arg, arguments_arg),
inputIndex(inputIndex_arg) inputIndex(inputIndex_arg)
{ {
// Add myself to the first derivative external function map
datatree.first_deriv_external_function_node_map[make_pair(make_pair(arguments,inputIndex),symb_id)] = this;
} }
NodeID NodeID
...@@ -3492,6 +3496,8 @@ SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datat ...@@ -3492,6 +3496,8 @@ SecondDerivExternalFunctionNode::SecondDerivExternalFunctionNode(DataTree &datat
inputIndex1(inputIndex1_arg), inputIndex1(inputIndex1_arg),
inputIndex2(inputIndex2_arg) inputIndex2(inputIndex2_arg)
{ {
// Add myself to the second derivative external function map
datatree.second_deriv_external_function_node_map[make_pair(make_pair(arguments,make_pair(inputIndex1,inputIndex2)),symb_id)] = this;
} }
NodeID NodeID
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment