/*
 * Copyright (C) 2007-2009 Dynare Team
 *
 * This file is part of Dynare.
 *
 * Dynare is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Dynare is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Dynare.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <iostream>
#include <iterator>
#include <algorithm>

24
#include <cassert>
25
#include <cmath>
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43

#include "ExprNode.hh"
#include "DataTree.hh"

ExprNode::ExprNode(DataTree &datatree_arg) : datatree(datatree_arg)
{
  // Register this node in the tree's node list
  datatree.node_list.push_back(this);

  // Take the tree's current counter value as my index, then advance the counter
  idx = datatree.node_counter;
  datatree.node_counter++;
}

// Destructor: performs no work of its own.
ExprNode::~ExprNode()
{
}

// Returns the derivative of this node with respect to deriv_id.
// Results are memoized in `derivatives`, so computeDerivative() runs at most
// once per deriv_id for a given node.
NodeID
ExprNode::getDerivative(int deriv_id)
{
  // Symbolic a priori: a deriv_id absent from non_null_derivatives gives zero
  if (non_null_derivatives.find(deriv_id) == non_null_derivatives.end())
    return datatree.Zero;

  // Serve the derivative from the cache when it was already computed
  map<int, NodeID>::const_iterator cached = derivatives.find(deriv_id);
  if (cached != derivatives.end())
    return cached->second;

  // First request for this deriv_id: compute it, cache it, return it
  NodeID result = computeDerivative(deriv_id);
  derivatives[deriv_id] = result;
  return result;
}

// Default precedence, used for constants, variables and unary ops.
int
ExprNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
{
  // These nodes get the maximal precedence
  const int maximal_precedence = 100;
  return maximal_precedence;
}

// Default cost, used for terminal nodes.
int
ExprNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) const
{
  // A terminal node costs nothing, whatever the target language
  const int terminal_cost = 0;
  return terminal_cost;
}

// Default implementation, used for terminal nodes.
void
ExprNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
                                temporary_terms_type &temporary_terms,
                                bool is_matlab) const
{
  // A terminal node never registers a temporary term: both containers are left untouched
}

void
ExprNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
                                temporary_terms_type &temporary_terms,
88
                                map<NodeID, pair<int, int> > &first_occurence,
89
90
                                int Curr_block,
                                Model_Block *ModelBlock,
91
                                int equation,
92
93
94
95
96
97
98
99
100
101
102
                                map_idx_type &map_idx) const
{
  // Nothing to do for a terminal node
}

// Convenience overload: writes the node with the oMatlabOutsideModel output
// type and an empty set of temporary terms.
void
ExprNode::writeOutput(ostream &output)
{
  temporary_terms_type no_temporary_terms;
  writeOutput(output, oMatlabOutsideModel, no_temporary_terms);
}

sebastien's avatar
sebastien committed
103

104
105
106
107
108
109
110
111
112
113
114
// Builds a node for the numerical constant identified by id_arg and registers
// it in the owning tree.
NumConstNode::NumConstNode(DataTree &datatree_arg, int id_arg) :
  ExprNode(datatree_arg),
  id(id_arg)
{
  // Register this node in the tree's constant map, keyed by the constant id
  datatree.num_const_node_map[id_arg] = this;

  // non_null_derivatives is deliberately left empty: every derivative of a constant is null
}

// The derivative of a constant is the tree's Zero node, whatever deriv_id is.
NodeID
NumConstNode::computeDerivative(int deriv_id)
{
  return datatree.Zero;
}

120
121
122
123
124
125
126
127
void
NumConstNode::collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const
{
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<NumConstNode *>(this));
  if (it != temporary_terms.end())
    ModelBlock->Block_List[Curr_Block].Temporary_InUse->insert(idx);
}

128
129
130
131
132
133
// Writes the constant to output: either its literal value, or a reference to
// the temporary term T<idx> when the node belongs to temporary_terms.
void
NumConstNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                          const temporary_terms_type &temporary_terms) const
{
  // Not a temporary term: write the constant as stored in the constants table
  if (temporary_terms.find(const_cast<NumConstNode *>(this)) == temporary_terms.end())
    {
      output << datatree.num_constants.get(id);
      return;
    }

  // Temporary term: the sparse dynamic Matlab output indexes it by it_
  if (output_type != oMatlabDynamicModelSparse)
    output << "T" << idx;
  else
    output << "T" << idx << "(it_)";
}

// Evaluates the constant; eval_context is not consulted.
double
NumConstNode::eval(const eval_context_type &eval_context) const throw (EvalException)
{
  const double value = datatree.num_constants.getDouble(id);
  return value;
}

void
sebastien's avatar
sebastien committed
149
NumConstNode::compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const
150
{
151
  CompileCode.write(&FLDC, sizeof(FLDC));
sebastien's avatar
sebastien committed
152
  double vard = datatree.num_constants.getDouble(id);
153
#ifdef DEBUGC
154
  cout << "FLDC " << vard << "\n";
155
#endif
156
  CompileCode.write(reinterpret_cast<char *>(&vard),sizeof(vard));
157
158
159
}

void
160
NumConstNode::collectEndogenous(set<pair<int, int> > &result) const
161
162
163
{
}

ferhat's avatar
ferhat committed
164
165
166
167
168
// A numerical constant adds nothing to result.
void
NumConstNode::collectExogenous(set<pair<int, int> > &result) const
{
}

sebastien's avatar
sebastien committed
169
170
171
172
173
174
// Adds this constant to static_datatree and returns the resulting node.
// The constant is looked up by id in this node's own tree and passed to
// AddNumConstant in the form returned by num_constants.get().
NodeID
NumConstNode::toStatic(DataTree &static_datatree) const
{
  return static_datatree.AddNumConstant(datatree.num_constants.get(id));
}

ferhat's avatar
ferhat committed
175

176
// Builds a node for symbol symb_id_arg taken with lead/lag lag_arg, registers
// it in the owning tree and initializes the set of potentially non-null
// derivatives according to the symbol type.
VariableNode::VariableNode(DataTree &datatree_arg, int symb_id_arg, int lag_arg, int deriv_id_arg) :
  ExprNode(datatree_arg),
  symb_id(symb_id_arg),
  type(datatree.symbol_table.getType(symb_id_arg)),
  lag(lag_arg),
  deriv_id(deriv_id_arg)
{
  // Register this node in the tree's variable map, keyed by (symbol, lag)
  datatree.variable_node_map[make_pair(symb_id, lag)] = this;

  // A lead/lag is accepted on parameters (during steady state calibration,
  // endogenous and parameters can be swapped), but not on local variables
  // nor on unknown functions
  assert(lag == 0 || (type != eModelLocalVariable && type != eModFileLocalVariable && type != eUnknownFunction));

  // Fill in non_null_derivatives
  if (type == eEndogenous || type == eExogenous || type == eExogenousDet || type == eParameter)
    {
      // For a variable or a parameter, the only non-null derivative is with respect to itself
      non_null_derivatives.insert(deriv_id);
    }
  else if (type == eModelLocalVariable)
    {
      // Inherit the non-null derivatives of the expression defining the local variable
      non_null_derivatives = datatree.local_variables_table[symb_id]->non_null_derivatives;
    }
  else if (type == eUnknownFunction)
    {
      cerr << "Attempt to construct a VariableNode with an unknown function name" << endl;
      exit(EXIT_FAILURE);
    }
  // eModFileLocalVariable: such a variable is never derived, so the set stays empty
}

// Computes the derivative of this variable with respect to deriv_id_arg.
// Exits the program for symbol types that cannot be derived.
NodeID
VariableNode::computeDerivative(int deriv_id_arg)
{
  switch(type)
    {
    case eEndogenous:
    case eExogenous:
    case eExogenousDet:
    case eParameter:
      // One when deriving with respect to this very variable, zero otherwise
      return (deriv_id == deriv_id_arg) ? datatree.One : datatree.Zero;

    case eModelLocalVariable:
      // Delegate to the expression stored for the local variable
      return datatree.local_variables_table[symb_id]->getDerivative(deriv_id_arg);

    case eModFileLocalVariable:
      cerr << "ModFileLocalVariable is not derivable" << endl;
      exit(EXIT_FAILURE);

    case eUnknownFunction:
      cerr << "Impossible case!" << endl;
      exit(EXIT_FAILURE);
    }

  // Only reached for a type not listed above; also suppresses a GCC warning
  exit(EXIT_FAILURE);
}

238
239
240
241
242
243
void
VariableNode::collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const
{
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<VariableNode *>(this));
  if (it != temporary_terms.end())
    ModelBlock->Block_List[Curr_Block].Temporary_InUse->insert(idx);
244
245
  if(type== eModelLocalVariable)
    datatree.local_variables_table[symb_id]->collectTemporary_terms(temporary_terms, ModelBlock, Curr_Block);
246
247
}

248
249
250
251
252
253
254
255
256
257
258
// Writes this variable to output in the syntax selected by output_type.
// The result is, in order of priority: a reference to temporary term T<idx>,
// a LaTeX name, or a type-dependent expression (params, y, x, oo_ fields, or
// the local variable's name/definition).
void
VariableNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                          const temporary_terms_type &temporary_terms) const
{
#ifdef DEBUGC
  cout << "write_ouput output_type=" << output_type << "\n";
#endif

  // Case 1: the node is a temporary term
  if (temporary_terms.find(const_cast<VariableNode *>(this)) != temporary_terms.end())
    {
      if (output_type != oMatlabDynamicModelSparse)
        output << "T" << idx;
      else
        output << "T" << idx << "(it_)";
      return;
    }

  // Case 2: LaTeX output; the dynamic model adds a time subscript such as _{t+1}
  if (IS_LATEX(output_type))
    {
      output << datatree.symbol_table.getTeXName(symb_id);
      if (output_type == oLatexDynamicModel)
        {
          output << "_{t";
          if (lag > 0)
            output << "+";
          if (lag != 0)
            output << lag;
          output << "}";
        }
      return;
    }

  // Case 3: code output, which depends on the symbol type
  int index;
  int type_specific_id = datatree.symbol_table.getTypeSpecificID(symb_id);
  switch(type)
    {
    case eParameter:
      if (output_type != oMatlabOutsideModel)
        output << "params" << LEFT_ARRAY_SUBSCRIPT(output_type) << type_specific_id + ARRAY_SUBSCRIPT_OFFSET(output_type) << RIGHT_ARRAY_SUBSCRIPT(output_type);
      else
        output << "M_.params" << "(" << type_specific_id + 1 << ")";
      break;

    case eModelLocalVariable:
    case eModFileLocalVariable:
      // Sparse Matlab outputs inline the parenthesized definition; others write the name
      if (output_type == oMatlabDynamicModelSparse || output_type == oMatlabStaticModelSparse)
        {
          output << "(";
          datatree.local_variables_table[symb_id]->writeOutput(output, output_type, temporary_terms);
          output << ")";
        }
      else
        output << datatree.symbol_table.getName(symb_id);
      break;

    case eEndogenous:
      switch(output_type)
        {
        case oMatlabDynamicModel:
        case oCDynamicModel:
          // Indexed by the column reported by getDynJacobianCol
          index = datatree.getDynJacobianCol(deriv_id) + ARRAY_SUBSCRIPT_OFFSET(output_type);
          output <<  "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << index << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case oMatlabStaticModel:
        case oMatlabStaticModelSparse:
        case oCStaticModel:
          index = type_specific_id + ARRAY_SUBSCRIPT_OFFSET(output_type);
          output <<  "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << index << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case oMatlabDynamicModelSparse:
          // Two subscripts: time (it_ shifted by the lag) and the type specific index
          index = type_specific_id + ARRAY_SUBSCRIPT_OFFSET(output_type);
          if (lag == 0)
            output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_, " << index << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else if (lag > 0)
            output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_+" << lag << ", " << index << RIGHT_ARRAY_SUBSCRIPT(output_type);
          else
            output << "y" << LEFT_ARRAY_SUBSCRIPT(output_type) << "it_" << lag << ", " << index << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case oMatlabOutsideModel:
          output << "oo_.steady_state" << "(" << type_specific_id + 1 << ")";
          break;
        default:
          assert(false);
        }
      break;

    case eExogenous:
      index = type_specific_id + ARRAY_SUBSCRIPT_OFFSET(output_type);
      switch(output_type)
        {
        case oMatlabDynamicModel:
        case oMatlabDynamicModelSparse:
          if (lag == 0)
            output <<  "x(it_, " << index << ")";
          else if (lag > 0)
            output <<  "x(it_+" << lag << ", " << index << ")";
          else
            output <<  "x(it_" << lag << ", " << index << ")";
          break;
        case oCDynamicModel:
          if (lag == 0)
            output <<  "x[it_+" << index << "*nb_row_x]";
          else if (lag > 0)
            output <<  "x[it_+" << lag << "+" << index << "*nb_row_x]";
          else
            output <<  "x[it_" << lag << "+" << index << "*nb_row_x]";
          break;
        case oMatlabStaticModel:
        case oMatlabStaticModelSparse:
        case oCStaticModel:
          output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << index << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case oMatlabOutsideModel:
          assert(lag == 0);
          output <<  "oo_.exo_steady_state" << "(" << index << ")";
          break;
        default:
          assert(false);
        }
      break;

    case eExogenousDet:
      // The index is shifted by the number of exogenous reported by the symbol table
      index = type_specific_id + datatree.symbol_table.exo_nbr() + ARRAY_SUBSCRIPT_OFFSET(output_type);
      switch(output_type)
        {
        case oMatlabDynamicModel:
        case oMatlabDynamicModelSparse:
          if (lag == 0)
            output <<  "x(it_, " << index << ")";
          else if (lag > 0)
            output <<  "x(it_+" << lag << ", " << index << ")";
          else
            output <<  "x(it_" << lag << ", " << index << ")";
          break;
        case oCDynamicModel:
          if (lag == 0)
            output <<  "x[it_+" << index << "*nb_row_xd]";
          else if (lag > 0)
            output <<  "x[it_+" << lag << "+" << index << "*nb_row_xd]";
          else
            output <<  "x[it_" << lag << "+" << index << "*nb_row_xd]";
          break;
        case oMatlabStaticModel:
        case oMatlabStaticModelSparse:
        case oCStaticModel:
          output << "x" << LEFT_ARRAY_SUBSCRIPT(output_type) << index << RIGHT_ARRAY_SUBSCRIPT(output_type);
          break;
        case oMatlabOutsideModel:
          assert(lag == 0);
          // Outside the model, the unshifted type specific index is used
          output <<  "oo_.exo_det_steady_state" << "(" << type_specific_id + 1 << ")";
          break;
        default:
          assert(false);
        }
      break;

    case eUnknownFunction:
      cerr << "Impossible case" << endl;
      exit(EXIT_FAILURE);
    }
}

double
VariableNode::eval(const eval_context_type &eval_context) const throw (EvalException)
{
sebastien's avatar
sebastien committed
416
  eval_context_type::const_iterator it = eval_context.find(symb_id);
417
  if (it == eval_context.end())
sebastien's avatar
sebastien committed
418
    throw EvalException();
419
420
421
422
423

  return it->second;
}

void
sebastien's avatar
sebastien committed
424
VariableNode::compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const
425
426
427
428
429
430
431
432
433
434
435
{
  int i, lagl;
#ifdef DEBUGC
  cout << "output_type=" << output_type << "\n";
#endif
  if(!lhs_rhs)
    CompileCode.write(&FLDV, sizeof(FLDV));
  else
    CompileCode.write(&FSTPV, sizeof(FSTPV));
  char typel=(char)type;
  CompileCode.write(&typel, sizeof(typel));
436
  int tsid = datatree.symbol_table.getTypeSpecificID(symb_id);
437
438
439
  switch(type)
    {
    case eParameter:
sebastien's avatar
sebastien committed
440
      i = tsid;
441
442
443
444
445
446
      CompileCode.write(reinterpret_cast<char *>(&i), sizeof(i));
#ifdef DEBUGC
      cout << "FLD Param[ " << i << ", symb_id=" << symb_id << "]\n";
#endif
      break;
    case eEndogenous :
sebastien's avatar
sebastien committed
447
      i = symb_id;
448
449
450
451
452
      CompileCode.write(reinterpret_cast<char *>(&i), sizeof(i));
      lagl=lag;
      CompileCode.write(reinterpret_cast<char *>(&lagl), sizeof(lagl));
      break;
    case eExogenous :
sebastien's avatar
sebastien committed
453
      i = tsid;
454
455
456
457
458
      CompileCode.write(reinterpret_cast<char *>(&i), sizeof(i));
      lagl=lag;
      CompileCode.write(reinterpret_cast<char *>(&lagl), sizeof(lagl));
      break;
    case eExogenousDet:
sebastien's avatar
sebastien committed
459
      i = tsid + datatree.symbol_table.exo_nbr();
460
461
462
463
464
465
      CompileCode.write(reinterpret_cast<char *>(&i), sizeof(i));
      lagl=lag;
      CompileCode.write(reinterpret_cast<char *>(&lagl), sizeof(lagl));
      break;
    case eModelLocalVariable:
    case eModFileLocalVariable:
sebastien's avatar
sebastien committed
466
      datatree.local_variables_table[symb_id]->compile(CompileCode, lhs_rhs, temporary_terms, map_idx);
467
      break;
468
    case eUnknownFunction:
469
      cerr << "Impossible case: eUnknownFuncion" << endl;
470
      exit(EXIT_FAILURE);
471
472
473
    }
}

474
475
void
VariableNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
476
477
478
479
480
481
                                    temporary_terms_type &temporary_terms,
                                    map<NodeID, pair<int, int> > &first_occurence,
                                    int Curr_block,
                                    Model_Block *ModelBlock,
                                    int equation,
                                    map_idx_type &map_idx) const
482
483
484
485
486
{
  if(type== eModelLocalVariable)
    datatree.local_variables_table[symb_id]->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, ModelBlock, equation, map_idx);
}

487
void
488
VariableNode::collectEndogenous(set<pair<int, int> > &result) const
489
490
{
  if (type == eEndogenous)
491
    result.insert(make_pair(symb_id, lag));
492
493
  else if (type == eModelLocalVariable)
    datatree.local_variables_table[symb_id]->collectEndogenous(result);
494
495
}

ferhat's avatar
ferhat committed
496
497
498
499
500
// Adds (symb_id, lag) to result for an exogenous variable; for a model local
// variable, collects from the expression stored for it. Other types add nothing.
void
VariableNode::collectExogenous(set<pair<int, int> > &result) const
{
  switch(type)
    {
    case eExogenous:
      result.insert(make_pair(symb_id, lag));
      break;
    case eModelLocalVariable:
      datatree.local_variables_table[symb_id]->collectExogenous(result);
      break;
    default:
      break;
    }
}

sebastien's avatar
sebastien committed
505
506
507
508
509
510
// Adds this variable to static_datatree and returns the resulting node.
// The variable is identified by its symbol name only: the lag of this node is
// not passed to AddVariable.
NodeID
VariableNode::toStatic(DataTree &static_datatree) const
{
  return static_datatree.AddVariable(datatree.symbol_table.getName(symb_id));
}

ferhat's avatar
ferhat committed
511

512
513
514
515
516
517
518
519
520
521
522
523
524
// Builds the node applying op_code_arg to arg_arg and registers it in the
// owning tree.
UnaryOpNode::UnaryOpNode(DataTree &datatree_arg, UnaryOpcode op_code_arg, const NodeID arg_arg) :
  ExprNode(datatree_arg),
  arg(arg_arg),
  op_code(op_code_arg)
{
  // Register this node in the tree's unary op map, keyed by (argument, opcode)
  datatree.unary_op_node_map[make_pair(arg_arg, op_code_arg)] = this;

  // A unary op can only have a non-null derivative where its argument has one
  non_null_derivatives = arg_arg->non_null_derivatives;
}

// Computes the derivative of f(arg) with respect to deriv_id using the chain
// rule: f'(arg) * darg, where darg is the derivative of the argument.
// In the formulas below, `this` stands for the value f(arg) of this node,
// which lets several derivatives reuse the node instead of building new ones.
// Exits the program if op_code is not one of the handled opcodes.
NodeID
UnaryOpNode::computeDerivative(int deriv_id)
{
  NodeID darg = arg->getDerivative(deriv_id);

  NodeID t11, t12, t13;

  switch(op_code)
    {
    case oUminus:
      // (-u)' = -u'
      return datatree.AddUMinus(darg);
    case oExp:
      // exp(u)' = u' * exp(u)
      return datatree.AddTimes(darg, this);
    case oLog:
      // log(u)' = u' / u
      return datatree.AddDivide(darg, arg);
    case oLog10:
      // log10(u)' = log10(e) * u' / u
      t11 = datatree.AddExp(datatree.One);
      t12 = datatree.AddLog10(t11);
      t13 = datatree.AddDivide(darg, arg);
      return datatree.AddTimes(t12, t13);
    case oCos:
      // cos(u)' = -sin(u) * u'
      t11 = datatree.AddSin(arg);
      t12 = datatree.AddUMinus(t11);
      return datatree.AddTimes(darg, t12);
    case oSin:
      // sin(u)' = cos(u) * u'
      t11 = datatree.AddCos(arg);
      return datatree.AddTimes(darg, t11);
    case oTan:
      // tan(u)' = (1 + tan(u)^2) * u'
      t11 = datatree.AddTimes(this, this);
      t12 = datatree.AddPlus(t11, datatree.One);
      return datatree.AddTimes(darg, t12);
    case oAcos:
      // acos(u)' = -u' / sin(acos(u))
      t11 = datatree.AddSin(this);
      t12 = datatree.AddDivide(darg, t11);
      return datatree.AddUMinus(t12);
    case oAsin:
      // asin(u)' = u' / cos(asin(u))
      t11 = datatree.AddCos(this);
      return datatree.AddDivide(darg, t11);
    case oAtan:
      // atan(u)' = u' / (1 + u^2)
      t11 = datatree.AddTimes(arg, arg);
      t12 = datatree.AddPlus(datatree.One, t11);
      return datatree.AddDivide(darg, t12);
    case oCosh:
      // cosh(u)' = sinh(u) * u'
      t11 = datatree.AddSinh(arg);
      return datatree.AddTimes(darg, t11);
    case oSinh:
      // sinh(u)' = cosh(u) * u'
      t11 = datatree.AddCosh(arg);
      return datatree.AddTimes(darg, t11);
    case oTanh:
      // tanh(u)' = (1 - tanh(u)^2) * u'
      t11 = datatree.AddTimes(this, this);
      t12 = datatree.AddMinus(datatree.One, t11);
      return datatree.AddTimes(darg, t12);
    case oAcosh:
      // acosh(u)' = u' / sinh(acosh(u))
      t11 = datatree.AddSinh(this);
      return datatree.AddDivide(darg, t11);
    case oAsinh:
      // asinh(u)' = u' / cosh(asinh(u))
      t11 = datatree.AddCosh(this);
      return datatree.AddDivide(darg, t11);
    case oAtanh:
      // atanh(u)' = u' / (1 - u^2): the denominator must divide darg
      // (multiplying by it, as was done before, gives a wrong derivative)
      t11 = datatree.AddTimes(arg, arg);
      t12 = datatree.AddMinus(datatree.One, t11);
      return datatree.AddDivide(darg, t12);
    case oSqrt:
      // sqrt(u)' = u' / (2 * sqrt(u))
      t11 = datatree.AddPlus(this, this);
      return datatree.AddDivide(darg, t11);
    }
  // Only reached for an unhandled opcode; also suppresses a GCC warning
  exit(EXIT_FAILURE);
}

int
UnaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) const
{
  // For a temporary term, the cost is null
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<UnaryOpNode *>(this));
  if (it != temporary_terms.end())
    return 0;

  int cost = arg->cost(temporary_terms, is_matlab);

  if (is_matlab)
    // Cost for Matlab files
    switch(op_code)
      {
      case oUminus:
        return cost + 70;
      case oExp:
        return cost + 160;
      case oLog:
        return cost + 300;
      case oLog10:
        return cost + 16000;
      case oCos:
      case oSin:
      case oCosh:
        return cost + 210;
      case oTan:
        return cost + 230;
      case oAcos:
        return cost + 300;
      case oAsin:
        return cost + 310;
      case oAtan:
        return cost + 140;
      case oSinh:
        return cost + 240;
      case oTanh:
        return cost + 190;
      case oAcosh:
        return cost + 770;
      case oAsinh:
        return cost + 460;
      case oAtanh:
        return cost + 350;
      case oSqrt:
        return cost + 570;
      }
  else
    // Cost for C files
    switch(op_code)
      {
      case oUminus:
        return cost + 3;
      case oExp:
      case oAcosh:
        return cost + 210;
      case oLog:
        return cost + 137;
      case oLog10:
        return cost + 139;
      case oCos:
      case oSin:
        return cost + 160;
      case oTan:
        return cost + 170;
      case oAcos:
      case oAtan:
        return cost + 190;
      case oAsin:
        return cost + 180;
      case oCosh:
      case oSinh:
      case oTanh:
        return cost + 240;
      case oAsinh:
        return cost + 220;
      case oAtanh:
        return cost + 150;
      case oSqrt:
        return cost + 90;
      }
sebastien's avatar
sebastien committed
675
  // Suppress GCC warning
676
  exit(EXIT_FAILURE);
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
}

// Counts the references to this node and promotes it to a temporary term once
// (reference count) * (cost) exceeds MIN_COST(is_matlab).
// The argument is only visited the first time the node is encountered.
void
UnaryOpNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
                                   temporary_terms_type &temporary_terms,
                                   bool is_matlab) const
{
  NodeID self = const_cast<UnaryOpNode *>(this);

  map<NodeID, int>::iterator counter = reference_count.find(self);
  if (counter == reference_count.end())
    {
      // First encounter: start counting and recurse into the argument
      reference_count[self] = 1;
      arg->computeTemporaryTerms(reference_count, temporary_terms, is_matlab);
      return;
    }

  // Later encounters: one more reference, then test the cost threshold
  counter->second++;
  if (counter->second * cost(temporary_terms, is_matlab) > MIN_COST(is_matlab))
    temporary_terms.insert(self);
}

void
UnaryOpNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
                                   temporary_terms_type &temporary_terms,
703
                                   map<NodeID, pair<int, int> > &first_occurence,
704
705
                                   int Curr_block,
                                   Model_Block *ModelBlock,
706
                                   int equation,
707
708
709
710
711
712
713
                                   map_idx_type &map_idx) const
{
  NodeID this2 = const_cast<UnaryOpNode *>(this);
  map<NodeID, int>::iterator it = reference_count.find(this2);
  if (it == reference_count.end())
    {
      reference_count[this2] = 1;
714
715
      first_occurence[this2] = make_pair(Curr_block,equation);
      arg->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, ModelBlock, equation, map_idx);
716
717
718
719
720
721
722
    }
  else
    {
      reference_count[this2]++;
      if (reference_count[this2] * cost(temporary_terms, false) > MIN_COST_C)
        {
          temporary_terms.insert(this2);
723
          ModelBlock->Block_List[first_occurence[this2].first].Temporary_Terms_in_Equation[first_occurence[this2].second]->insert(this2);
724
725
726
727
        }
    }
}

728
729
730
731
732
733
734
735
736
737
void
UnaryOpNode::collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const
{
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<UnaryOpNode*>(this));
  if (it != temporary_terms.end())
    ModelBlock->Block_List[Curr_Block].Temporary_InUse->insert(idx);
  else
    arg->collectTemporary_terms(temporary_terms, ModelBlock, Curr_Block);
}

738
739
740
741
742
743
744
745
// Writes this unary operation to output in the syntax selected by
// output_type: either a reference to temporary term T<idx>, or the operator
// name followed by its (possibly parenthesized) argument.
void
UnaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                         const temporary_terms_type &temporary_terms) const
{
  // A temporary term is written as a reference, not expanded
  if (temporary_terms.find(const_cast<UnaryOpNode *>(this)) != temporary_terms.end())
    {
      if (output_type != oMatlabDynamicModelSparse)
        output << "T" << idx;
      else
        output << "T" << idx << "(it_)";
      return;
    }

  const bool is_uminus = (op_code == oUminus);

  // A uminus node is always wrapped in its own pair of parentheses
  if (is_uminus)
    output << LEFT_PAR(output_type);

  // Operator name
  switch(op_code)
    {
    case oUminus:
      output << "-";
      break;
    case oExp:
      output << "exp";
      break;
    case oLog:
      output << "log";
      break;
    case oLog10:
      // Only log10 has a LaTeX specific spelling
      if (IS_LATEX(output_type))
        output << "log_{10}";
      else
        output << "log10";
      break;
    case oCos:
      output << "cos";
      break;
    case oSin:
      output << "sin";
      break;
    case oTan:
      output << "tan";
      break;
    case oAcos:
      output << "acos";
      break;
    case oAsin:
      output << "asin";
      break;
    case oAtan:
      output << "atan";
      break;
    case oCosh:
      output << "cosh";
      break;
    case oSinh:
      output << "sinh";
      break;
    case oTanh:
      output << "tanh";
      break;
    case oAcosh:
      output << "acosh";
      break;
    case oAsinh:
      output << "asinh";
      break;
    case oAtanh:
      output << "atanh";
      break;
    case oSqrt:
      output << "sqrt";
      break;
    }

  /* The argument is parenthesized for every function call; for uminus, only
     when the argument has a lower precedence than this node */
  const bool wrap_argument = !is_uminus
    || arg->precedence(output_type, temporary_terms) < precedence(output_type, temporary_terms);

  if (wrap_argument)
    output << LEFT_PAR(output_type);

  arg->writeOutput(output, output_type, temporary_terms);

  if (wrap_argument)
    output << RIGHT_PAR(output_type);

  // Close the parenthesis opened for uminus
  if (is_uminus)
    output << RIGHT_PAR(output_type);
}

// Applies the unary operator op_code to the value v using the matching
// <cmath> function. Exits the program for an unhandled opcode.
double
UnaryOpNode::eval_opcode(UnaryOpcode op_code, double v) throw (EvalException)
{
  switch(op_code)
    {
    case oUminus:
      return -v;
    case oExp:
      return exp(v);
    case oLog:
      return log(v);
    case oLog10:
      return log10(v);
    case oCos:
      return cos(v);
    case oSin:
      return sin(v);
    case oTan:
      return tan(v);
    case oAcos:
      return acos(v);
    case oAsin:
      return asin(v);
    case oAtan:
      return atan(v);
    case oCosh:
      return cosh(v);
    case oSinh:
      return sinh(v);
    case oTanh:
      return tanh(v);
    case oAcosh:
      return acosh(v);
    case oAsinh:
      return asinh(v);
    case oAtanh:
      return atanh(v);
    case oSqrt:
      return sqrt(v);
    }

  // Only reached for an unhandled opcode; also suppresses a GCC warning
  exit(EXIT_FAILURE);
}

// Evaluates the argument in eval_context, then applies this node's operator
// to the result.
double
UnaryOpNode::eval(const eval_context_type &eval_context) const throw (EvalException)
{
  const double arg_value = arg->eval(eval_context);
  return eval_opcode(op_code, arg_value);
}

void
sebastien's avatar
sebastien committed
893
UnaryOpNode::compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const
894
895
896
897
898
899
900
901
902
{
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<UnaryOpNode *>(this));
  if (it != temporary_terms.end())
    {
      CompileCode.write(&FLDT, sizeof(FLDT));
      int var=map_idx[idx];
      CompileCode.write(reinterpret_cast<char *>(&var), sizeof(var));
      return;
    }
sebastien's avatar
sebastien committed
903
  arg->compile(CompileCode, lhs_rhs, temporary_terms, map_idx);
904
905
906
907
908
909
  CompileCode.write(&FUNARY, sizeof(FUNARY));
  UnaryOpcode op_codel=op_code;
  CompileCode.write(reinterpret_cast<char *>(&op_codel), sizeof(op_codel));
}

void
910
UnaryOpNode::collectEndogenous(set<pair<int, int> > &result) const
911
{
912
  arg->collectEndogenous(result);
913
914
}

ferhat's avatar
ferhat committed
915
916
917
918
919
920
// A unary node adds nothing to result by itself: the collection is delegated to its argument
void
UnaryOpNode::collectExogenous(set<pair<int, int> > &result) const
{
  arg->collectExogenous(result);
}

sebastien's avatar
sebastien committed
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
// Rebuilds this node inside static_datatree: the argument is converted first,
// then the Add* builder of the static tree matching op_code is applied to it
NodeID
UnaryOpNode::toStatic(DataTree &static_datatree) const
{
  NodeID static_arg = arg->toStatic(static_datatree);

  switch (op_code)
    {
    case oUminus:
      return static_datatree.AddUMinus(static_arg);
    case oExp:
      return static_datatree.AddExp(static_arg);
    case oLog:
      return static_datatree.AddLog(static_arg);
    case oLog10:
      return static_datatree.AddLog10(static_arg);
    case oCos:
      return static_datatree.AddCos(static_arg);
    case oSin:
      return static_datatree.AddSin(static_arg);
    case oTan:
      return static_datatree.AddTan(static_arg);
    case oAcos:
      return static_datatree.AddAcos(static_arg);
    case oAsin:
      return static_datatree.AddAsin(static_arg);
    case oAtan:
      return static_datatree.AddAtan(static_arg);
    case oCosh:
      return static_datatree.AddCosh(static_arg);
    case oSinh:
      return static_datatree.AddSinh(static_arg);
    case oTanh:
      return static_datatree.AddTanh(static_arg);
    case oAcosh:
      return static_datatree.AddAcosh(static_arg);
    case oAsinh:
      return static_datatree.AddAsinh(static_arg);
    case oAtanh:
      return static_datatree.AddAtanh(static_arg);
    case oSqrt:
      return static_datatree.AddSqrt(static_arg);
    }

  // Only reached if op_code matches none of the cases above; also keeps GCC
  // from warning about a missing return value
  exit(EXIT_FAILURE);
}

ferhat's avatar
ferhat committed
966

967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
// Builds a binary node over arg1_arg and arg2_arg and registers it in the
// binary node map of the data tree
BinaryOpNode::BinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
                           BinaryOpcode op_code_arg, const NodeID arg2_arg) :
  ExprNode(datatree_arg),
  arg1(arg1_arg),
  arg2(arg2_arg),
  op_code(op_code_arg)
{
  datatree.binary_op_node_map[make_pair(make_pair(arg1, arg2), op_code)] = this;

  // A derivative can be non-null only if it is non-null for at least one of
  // the arguments: accumulate both sets (duplicates are discarded by the set)
  non_null_derivatives.insert(arg1->non_null_derivatives.begin(),
                              arg1->non_null_derivatives.end());
  non_null_derivatives.insert(arg2->non_null_derivatives.begin(),
                              arg2->non_null_derivatives.end());
}

// Builds, inside the data tree, the expression of the derivative of this node
// with respect to deriv_id, from the two arguments (u, v) and their
// derivatives (u', v'). Intermediate nodes are created one statement at a
// time so that their creation order in the data tree is fixed.
NodeID
BinaryOpNode::computeDerivative(int deriv_id)
{
  NodeID darg1 = arg1->getDerivative(deriv_id);
  NodeID darg2 = arg2->getDerivative(deriv_id);

  switch (op_code)
    {
    case oPlus:
      return datatree.AddPlus(darg1, darg2);
    case oMinus:
    case oEqual:
      // An equality is differentiated like the difference of its two sides
      return datatree.AddMinus(darg1, darg2);
    case oTimes:
      {
        // (u*v)' = u'*v + v'*u
        NodeID du_times_v = datatree.AddTimes(darg1, arg2);
        NodeID dv_times_u = datatree.AddTimes(darg2, arg1);
        return datatree.AddPlus(du_times_v, dv_times_u);
      }
    case oDivide:
      {
        // (u/v)' = (u'*v - v'*u) / (v*v)
        NodeID du_times_v = datatree.AddTimes(darg1, arg2);
        NodeID dv_times_u = datatree.AddTimes(darg2, arg1);
        NodeID numerator = datatree.AddMinus(du_times_v, dv_times_u);
        NodeID denominator = datatree.AddTimes(arg2, arg2);
        return datatree.AddDivide(numerator, denominator);
      }
    case oLess:
    case oGreater:
    case oLessEqual:
    case oGreaterEqual:
    case oEqualEqual:
    case oDifferent:
      // Comparisons are given a null derivative
      return datatree.Zero;
    case oPower:
      {
        if (darg2 != datatree.Zero)
          {
            // General case: (u^v)' = (v'*log(u) + u'*v/u) * u^v
            NodeID log_u = datatree.AddLog(arg1);
            NodeID dv_log_u = datatree.AddTimes(darg2, log_u);
            NodeID du_times_v = datatree.AddTimes(darg1, arg2);
            NodeID du_v_over_u = datatree.AddDivide(du_times_v, arg1);
            NodeID factor = datatree.AddPlus(dv_log_u, du_v_over_u);
            return datatree.AddTimes(factor, this);
          }

        // From here on v' is null
        if (darg1 == datatree.Zero)
          return datatree.Zero;

        // (u^v)' = u' * v * u^(v-1)
        NodeID v_minus_one = datatree.AddMinus(arg2, datatree.One);
        NodeID u_pow = datatree.AddPower(arg1, v_minus_one);
        NodeID v_times_u_pow = datatree.AddTimes(arg2, u_pow);
        return datatree.AddTimes(darg1, v_times_u_pow);
      }
    case oMax:
      {
        // max(u,v)' = (u>v)*u' + (1-(u>v))*v'
        NodeID picks_u = datatree.AddGreater(arg1, arg2);
        NodeID u_part = datatree.AddTimes(picks_u, darg1);
        NodeID picks_v = datatree.AddMinus(datatree.One, picks_u);
        NodeID v_part = datatree.AddTimes(picks_v, darg2);
        return datatree.AddPlus(v_part, u_part);
      }
    case oMin:
      {
        // min(u,v)' = (v>u)*u' + (1-(v>u))*v'
        NodeID picks_u = datatree.AddGreater(arg2, arg1);
        NodeID u_part = datatree.AddTimes(picks_u, darg1);
        NodeID picks_v = datatree.AddMinus(datatree.One, picks_u);
        NodeID v_part = datatree.AddTimes(picks_v, darg2);
        return datatree.AddPlus(v_part, u_part);
      }
    }

  // Only reached if op_code matches none of the cases above; also keeps GCC
  // from warning about a missing return value
  exit(EXIT_FAILURE);
}

// Returns the precedence level used when printing this node: the higher the
// value, the tighter the operator binds. 100 is returned for anything that
// never needs surrounding parentheses.
int
BinaryOpNode::precedence(ExprNodeOutputType output_type, const temporary_terms_type &temporary_terms) const
{
  // A temporary term is printed as a plain name, hence behaves as a variable
  if (temporary_terms.find(const_cast<BinaryOpNode *>(this)) != temporary_terms.end())
    return 100;

  switch (op_code)
    {
    case oEqual:
      return 0;
    case oEqualEqual:
    case oDifferent:
      return 1;
    case oLessEqual:
    case oGreaterEqual:
    case oLess:
    case oGreater:
      return 2;
    case oPlus:
    case oMinus:
      return 3;
    case oTimes:
    case oDivide:
      return 4;
    case oPower:
      // In C the power operator is written as a function call: pow(a, b)
      if (IS_C(output_type))
        return 100;
      return 5;
    case oMin:
    case oMax:
      return 100;
    }

  // Only reached if op_code matches none of the cases above; also keeps GCC
  // from warning about a missing return value
  exit(EXIT_FAILURE);
}

int
BinaryOpNode::cost(const temporary_terms_type &temporary_terms, bool is_matlab) const
{
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<BinaryOpNode *>(this));
  // For a temporary term, the cost is null
  if (it != temporary_terms.end())
    return 0;

  int cost = arg1->cost(temporary_terms, is_matlab);
  cost += arg2->cost(temporary_terms, is_matlab);

  if (is_matlab)
    // Cost for Matlab files
    switch(op_code)
      {
      case oLess:
      case oGreater:
      case oLessEqual:
      case oGreaterEqual:
      case oEqualEqual:
      case oDifferent:
        return cost + 60;
      case oPlus:
      case oMinus:
      case oTimes:
        return cost + 90;
      case oMax:
      case oMin:
1125
        return cost + 110;
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
      case oDivide:
        return cost + 990;
      case oPower:
        return cost + 1160;
      case oEqual:
        return cost;
      }
  else
    // Cost for C files
    switch(op_code)
      {
      case oLess:
      case oGreater:
      case oLessEqual:
      case oGreaterEqual:
      case oEqualEqual:
      case oDifferent:
        return cost + 2;
      case oPlus:
      case oMinus:
      case oTimes:
        return cost + 4;
      case oMax:
      case oMin:
	return cost + 5;
      case oDivide:
        return cost + 15;
      case oPower:
        return cost + 520;
      case oEqual:
        return cost;
      }
sebastien's avatar
sebastien committed
1158
  // Suppress GCC warning
1159
  exit(EXIT_FAILURE);
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
}

// Counts the references to this node and, once it has been met more than
// once, declares it a temporary term if recomputing it would be too costly
void
BinaryOpNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
                                    temporary_terms_type &temporary_terms,
                                    bool is_matlab) const
{
  NodeID self = const_cast<BinaryOpNode *>(this);
  map<NodeID, int>::iterator counter = reference_count.find(self);

  if (counter == reference_count.end())
    {
      // First encounter: start the count at one and visit the children
      reference_count[self] = 1;
      arg1->computeTemporaryTerms(reference_count, temporary_terms, is_matlab);
      arg2->computeTemporaryTerms(reference_count, temporary_terms, is_matlab);
      return;
    }

  // Already encountered: one more reference; the node becomes a temporary
  // term when (number of references) x (cost) exceeds the threshold
  counter->second++;
  if (counter->second * cost(temporary_terms, is_matlab) > MIN_COST(is_matlab))
    temporary_terms.insert(self);
}

void
BinaryOpNode::computeTemporaryTerms(map<NodeID, int> &reference_count,
                                    temporary_terms_type &temporary_terms,
1190
                                    map<NodeID, pair<int, int> > &first_occurence,
1191
1192
                                    int Curr_block,
                                    Model_Block *ModelBlock,
1193
                                    int equation,
1194
1195
1196
1197
1198
1199
1200
                                    map_idx_type &map_idx) const
{
  NodeID this2 = const_cast<BinaryOpNode *>(this);
  map<NodeID, int>::iterator it = reference_count.find(this2);
  if (it == reference_count.end())
    {
      reference_count[this2] = 1;
1201
1202
1203
      first_occurence[this2] = make_pair(Curr_block, equation);
      arg1->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, ModelBlock, equation, map_idx);
      arg2->computeTemporaryTerms(reference_count, temporary_terms, first_occurence, Curr_block, ModelBlock, equation, map_idx);
1204
1205
1206
1207
1208
1209
1210
    }
  else
    {
      reference_count[this2]++;
      if (reference_count[this2] * cost(temporary_terms, false) > MIN_COST_C)
        {
          temporary_terms.insert(this2);
1211
          ModelBlock->Block_List[first_occurence[this2].first].Temporary_Terms_in_Equation[first_occurence[this2].second]->insert(this2);
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
        }
    }
}

// Applies the binary operator op_code to the numerical values v1 and v2.
// Comparisons yield 1 when true and 0 when false.
// Throws EvalException for oEqual, which has no numerical value.
double
BinaryOpNode::eval_opcode(double v1, BinaryOpcode op_code, double v2) throw (EvalException)
{
  switch (op_code)
    {
    case oPlus:
      return v1 + v2;
    case oMinus:
      return v1 - v2;
    case oTimes:
      return v1 * v2;
    case oDivide:
      return v1 / v2;
    case oPower:
      return pow(v1, v2);
    case oMax:
      return (v1 < v2) ? v2 : v1;
    case oMin:
      return (v1 > v2) ? v2 : v1;
    case oLess:
      return v1 < v2;
    case oGreater:
      return v1 > v2;
    case oLessEqual:
      return v1 <= v2;
    case oGreaterEqual:
      return v1 >= v2;
    case oEqualEqual:
      return v1 == v2;
    case oDifferent:
      return v1 != v2;
    case oEqual:
      throw EvalException();
    }

  // Only reached if op_code matches none of the cases above; also keeps GCC
  // from warning about a missing return value
  exit(EXIT_FAILURE);
}

// Numerically evaluates this node: first argument, then second argument, both
// in eval_context, and finally this node's operator applied to the two values
double
BinaryOpNode::eval(const eval_context_type &eval_context) const throw (EvalException)
{
  const double left_value = arg1->eval(eval_context);
  const double right_value = arg2->eval(eval_context);
  return eval_opcode(left_value, op_code, right_value);
}

void
sebastien's avatar
sebastien committed
1270
BinaryOpNode::compile(ofstream &CompileCode, bool lhs_rhs, const temporary_terms_type &temporary_terms, map_idx_type &map_idx) const
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
{
  // If current node is a temporary term
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<BinaryOpNode *>(this));
  if (it != temporary_terms.end())
    {
      CompileCode.write(&FLDT, sizeof(FLDT));
      int var=map_idx[idx];
      CompileCode.write(reinterpret_cast<char *>(&var), sizeof(var));
      return;
    }
sebastien's avatar
sebastien committed
1281
1282
  arg1->compile(CompileCode, lhs_rhs, temporary_terms, map_idx);
  arg2->compile(CompileCode, lhs_rhs, temporary_terms, map_idx);
1283
1284
1285
1286
1287
  CompileCode.write(&FBINARY, sizeof(FBINARY));
  BinaryOpcode op_codel=op_code;
  CompileCode.write(reinterpret_cast<char *>(&op_codel),sizeof(op_codel));
}

1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
void
BinaryOpNode::collectTemporary_terms(const temporary_terms_type &temporary_terms, Model_Block *ModelBlock, int Curr_Block) const
{
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<BinaryOpNode *>(this));
  if (it != temporary_terms.end())
    ModelBlock->Block_List[Curr_Block].Temporary_InUse->insert(idx);
  else
    {
      arg1->collectTemporary_terms(temporary_terms, ModelBlock, Curr_Block);
      arg2->collectTemporary_terms(temporary_terms, ModelBlock, Curr_Block);
    }
}


1302
1303
1304
1305
1306
1307
1308
1309
void
BinaryOpNode::writeOutput(ostream &output, ExprNodeOutputType output_type,
                          const temporary_terms_type &temporary_terms) const
{
  // If current node is a temporary term
  temporary_terms_type::const_iterator it = temporary_terms.find(const_cast<BinaryOpNode *>(this));
  if (it != temporary_terms.end())
    {
sebastien's avatar
sebastien committed
1310
      if (output_type == oMatlabDynamicModelSparse)
1311
1312
1313
1314
1315
1316
1317
        output << "T" << idx << "(it_)";
      else
        output << "T" << idx;
      return;
    }

  // Treat special case of power operator in C, and case of max and min operators
1318
  if ((op_code == oPower && IS_C(output_type)) || op_code == oMax || op_code == oMin )
1319
1320
    {
      switch (op_code)
1321
        {
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
        case oPower:
          output << "pow(";
          break;
        case oMax:
          output << "max(";
          break;
        case oMin:
          output << "min(";
          break;
        default:
          ;
1333
        }
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
      arg1->writeOutput(output, output_type, temporary_terms);
      output << ",";
      arg2->writeOutput(output, output_type, temporary_terms);
      output << ")";
      return;
    }

  int prec = precedence(output_type, temporary_terms);

  bool close_parenthesis = false;

1345
1346
1347
  if (IS_LATEX(output_type) && op_code == oDivide)
    output << "\\frac{";
  else
1348
    {
1349
1350
1351
1352
1353
1354
1355
1356
      // If left argument has a lower precedence, or if current and left argument are both power operators, add parenthesis around left argument
      BinaryOpNode *barg1 = dynamic_cast<BinaryOpNode *>(arg1);
      if (arg1->precedence(output_type, temporary_terms) < prec
          || (op_code == oPower && barg1 != NULL && barg1->op_code == oPower))
        {
          output << LEFT_PAR(output_type);
          close_parenthesis = true;
        }
1357
1358
1359
1360
1361
1362
    }

  // Write left argument
  arg1->writeOutput(output, output_type, temporary_terms);

  if (close_parenthesis)
1363
1364
1365
1366
1367
    output << RIGHT_PAR(output_type);

  if (IS_LATEX(output_type) && op_code == oDivide)
    output << "}";

1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378

  // Write current operator symbol
  switch(op_code)
    {
    case oPlus:
      output << "+";
      break;
    case oMinus:
      output << "-";
      break;
    case oTimes:
1379
1380
1381
1382
      if (IS_LATEX(output_type))
        output << "\\cdot ";
      else
        output << "*";
1383
1384
      break;
    case oDivide:
1385
1386
      if (!IS_LATEX(output_type))
        output << "/";
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
      break;
    case oPower:
      output << "^";
      break;
    case oLess:
      output << "<";
      break;
    case oGreater:
      output << ">";
      break;
    case oLessEqual:
1398
1399
1400
1401
      if (IS_LATEX(output_type))
        output << "\\leq ";
      else
        output << "<=";
1402
1403
      break;
    case oGreaterEqual:
1404
1405
1406
1407
      if (IS_LATEX(output_type))
        output << "\\geq ";
      else
        output << ">=";
1408
1409
1410
1411
1412
      break;
    case oEqualEqual:
      output << "==";
      break;
    case oDifferent:
1413
      if (IS_MATLAB(output_type))
1414
1415
        output << "~=";
      else
1416
1417
1418
1419
1420
1421
        {
          if (IS_C(output_type))
            output << "!=";
          else
            output << "\\neq ";
        }
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
      break;
    case oEqual:
      output << "=";
      break;
    default:
      ;
    }

  close_parenthesis = false;

1432
1433
1434
  if (IS_LATEX(output_type) && (op_code == oPower || op_code == oDivide))
    output << "{";
  else
1435
    {
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
      /* Add parenthesis around right argument if:
         - its precedence is lower than those of the current node
         - it is a power operator and current operator is also a power operator
         - it is a minus operator with same precedence than current operator
         - it is a divide operator with same precedence than current operator */
      BinaryOpNode *barg2 = dynamic_cast<BinaryOpNode *>(arg2);
      int arg2_prec = arg2->precedence(output_type, temporary_terms);
      if (arg2_prec < prec
          || (op_code == oPower && barg2 != NULL && barg2->op_code == oPower && !IS_LATEX(output_type))
          || (op_code == oMinus && arg2_prec == prec)
          || (op_code == oDivide && arg2_prec == prec && !IS_LATEX(output_type)))
        {
          output << LEFT_PAR(output_type);
          close_parenthesis = true;
        }
1451
1452
1453
1454
1455
    }

  // Write right argument
  arg2->writeOutput(output, output_type, temporary_terms);

1456
1457
1458
  if (IS_LATEX(output_type) && (op_code == oPower || op_code == oDivide))
    output << "}";

1459
  if (close_parenthesis)
1460
    output << RIGHT_PAR(output_type);
1461
1462
1463
}

void
1464
BinaryOpNode::collectEndogenous(set<pair<int, int> > &result) const
1465
{
1466
1467
  arg1->collectEndogenous(result);
  arg2->collectEndogenous(result);
1468
1469
}

ferhat's avatar
ferhat committed
1470
1471
1472
1473
1474
1475
1476
// A binary node adds nothing to result by itself: the collection is delegated
// to its first, then to its second argument
void
BinaryOpNode::collectExogenous(set<pair<int, int> > &result) const
{
  arg1->collectExogenous(result);
  arg2->collectExogenous(result);
}

sebastien's avatar
sebastien committed
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
// Rebuilds this node inside static_datatree: both arguments are converted
// (first one, then second one), then the Add* builder of the static tree
// matching op_code is applied to them
NodeID
BinaryOpNode::toStatic(DataTree &static_datatree) const
{
  NodeID static_lhs = arg1->toStatic(static_datatree);
  NodeID static_rhs = arg2->toStatic(static_datatree);

  switch (op_code)
    {
    case oPlus:
      return static_datatree.AddPlus(static_lhs, static_rhs);
    case oMinus:
      return static_datatree.AddMinus(static_lhs, static_rhs);
    case oTimes:
      return static_datatree.AddTimes(static_lhs, static_rhs);
    case oDivide:
      return static_datatree.AddDivide(static_lhs, static_rhs);
    case oPower:
      return static_datatree.AddPower(static_lhs, static_rhs);
    case oEqual:
      return static_datatree.AddEqual(static_lhs, static_rhs);
    case oMax:
      return static_datatree.AddMax(static_lhs, static_rhs);
    case oMin:
      return static_datatree.AddMin(static_lhs, static_rhs);
    case oLess:
      return static_datatree.AddLess(static_lhs, static_rhs);
    case oGreater:
      return static_datatree.AddGreater(static_lhs, static_rhs);
    case oLessEqual:
      return static_datatree.AddLessEqual(static_lhs, static_rhs);
    case oGreaterEqual:
      return static_datatree.AddGreaterEqual(static_lhs, static_rhs);
    case oEqualEqual:
      return static_datatree.AddEqualEqual(static_lhs, static_rhs);
    case oDifferent:
      return static_datatree.AddDifferent(static_lhs, static_rhs);
    }

  // Only reached if op_code matches none of the cases above; also keeps GCC
  // from warning about a missing return value
  exit(EXIT_FAILURE);
}


1518
// Builds a trinary node over its three arguments and registers it in the
// trinary node map of the data tree
TrinaryOpNode::TrinaryOpNode(DataTree &datatree_arg, const NodeID arg1_arg,
                             TrinaryOpcode op_code_arg, const NodeID arg2_arg, const NodeID arg3_arg) :
  ExprNode(datatree_arg),
  arg1(arg1_arg),
  arg2(arg2_arg),
  arg3(arg3_arg),
  op_code(op_code_arg)
{
  datatree.trinary_op_node_map[make_pair(make_pair(make_pair(arg1, arg2), arg3), op_code)] = this;

  // A derivative can be non-null only if it is non-null for at least one of
  // the arguments: accumulate the three sets (duplicates are discarded by the set)
  non_null_derivatives.insert(arg1->non_null_derivatives.begin(),
                              arg1->non_null_derivatives.end());
  non_null_derivatives.insert(arg2->non_null_derivatives.begin(),
                              arg2->non_null_derivatives.end());
  non_null_derivatives.insert(arg3->non_null_derivatives.begin(),
                              arg3->non_null_derivatives.end());
}

NodeID
1544
TrinaryOpNode::computeDerivative(int deriv_id)
1545
{
1546
1547
1548
  NodeID darg1 = arg1->getDerivative(deriv_id);
  NodeID darg2 = arg2->getDerivative(deriv_id);
  NodeID darg3 = arg3->getDerivative(deriv_id);
1549
1550
1551
1552
1553
1554
1555
1556
1557

  NodeID t11, t12, t13, t14, t15;

  switch(op_code)
    {
    case oNormcdf:
      // normal pdf is inlined in the tree
      NodeID y;
      // sqrt(2*pi)
sebastien's avatar
sebastien committed
1558
      t14 = datatree.AddSqrt(datatree.AddTimes(datatree.Two, datatree.Pi));
1559
1560
1561
1562
1563
1564
1565
1566
1567