Skip to content

Commit bf0a6e0

Browse files
committed
Adds IndexNotationVisitors
1 parent b29b033 commit bf0a6e0

File tree

16 files changed

+136
-135
lines changed

16 files changed

+136
-135
lines changed

include/taco/index_notation/expr_node.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
namespace taco {
1212

1313
class IndexVar;
14-
class ExprVisitorStrict;
14+
class IndexExprVisitorStrict;
15+
class IndexNotationVisitorStrict;
1516
class OperatorSplit;
1617

1718
/// A node of a scalar index expression tree.
@@ -20,7 +21,7 @@ struct ExprNode : public util::Manageable<ExprNode>, private util::Uncopyable {
2021
ExprNode();
2122
ExprNode(DataType type);
2223
virtual ~ExprNode() = default;
23-
virtual void accept(ExprVisitorStrict*) const = 0;
24+
virtual void accept(IndexExprVisitorStrict*) const = 0;
2425

2526
/// Split the expression.
2627
void splitOperator(IndexVar old, IndexVar left, IndexVar right);
@@ -43,7 +44,7 @@ struct TensorExprNode : public util::Manageable<TensorExprNode>,
4344
TensorExprNode();
4445
TensorExprNode(Type type);
4546
virtual ~TensorExprNode() = default;
46-
virtual void accept(ExprVisitorStrict*) const = 0;
47+
virtual void accept(IndexNotationVisitorStrict*) const = 0;
4748

4849
Type getType() const;
4950

include/taco/index_notation/expr_nodes.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ struct AccessNode : public ExprNode {
1717
AccessNode(TensorVar tensorVar, const std::vector<IndexVar>& indices)
1818
: ExprNode(tensorVar.getType().getDataType()), tensorVar(tensorVar), indexVars(indices) {}
1919

20-
void accept(ExprVisitorStrict* v) const {
20+
void accept(IndexExprVisitorStrict* v) const {
2121
v->visit(this);
2222
}
2323

@@ -46,15 +46,15 @@ struct UnaryExprNode : public ExprNode {
4646
struct NegNode : public UnaryExprNode {
4747
NegNode(IndexExpr operand) : UnaryExprNode(operand) {}
4848

49-
void accept(ExprVisitorStrict* v) const {
49+
void accept(IndexExprVisitorStrict* v) const {
5050
v->visit(this);
5151
}
5252
};
5353

5454
struct SqrtNode : public UnaryExprNode {
5555
SqrtNode(IndexExpr operand) : UnaryExprNode(operand) {}
5656

57-
void accept(ExprVisitorStrict* v) const {
57+
void accept(IndexExprVisitorStrict* v) const {
5858
v->visit(this);
5959
}
6060

@@ -80,7 +80,7 @@ struct AddNode : public BinaryExprNode {
8080
return "+";
8181
}
8282

83-
void accept(ExprVisitorStrict* v) const {
83+
void accept(IndexExprVisitorStrict* v) const {
8484
v->visit(this);
8585
}
8686
};
@@ -92,7 +92,7 @@ struct SubNode : public BinaryExprNode {
9292
return "-";
9393
}
9494

95-
void accept(ExprVisitorStrict* v) const {
95+
void accept(IndexExprVisitorStrict* v) const {
9696
v->visit(this);
9797
}
9898
};
@@ -104,7 +104,7 @@ struct MulNode : public BinaryExprNode {
104104
return "*";
105105
}
106106

107-
void accept(ExprVisitorStrict* v) const {
107+
void accept(IndexExprVisitorStrict* v) const {
108108
v->visit(this);
109109
}
110110
};
@@ -116,15 +116,15 @@ struct DivNode : public BinaryExprNode {
116116
return "/";
117117
}
118118

119-
void accept(ExprVisitorStrict* v) const {
119+
void accept(IndexExprVisitorStrict* v) const {
120120
v->visit(this);
121121
}
122122
};
123123

124124
struct ReductionNode : public ExprNode {
125125
ReductionNode(IndexExpr op, IndexVar var, IndexExpr a);
126126

127-
void accept(ExprVisitorStrict* v) const {
127+
void accept(IndexExprVisitorStrict* v) const {
128128
v->visit(this);
129129
}
130130

@@ -137,7 +137,7 @@ struct ReductionNode : public ExprNode {
137137
struct IntImmNode : public ImmExprNode {
138138
IntImmNode(long long val) : ImmExprNode(Int(sizeof(long long)*8)), val(val) {}
139139

140-
void accept(ExprVisitorStrict* v) const {
140+
void accept(IndexExprVisitorStrict* v) const {
141141
v->visit(this);
142142
}
143143

@@ -147,7 +147,7 @@ struct IntImmNode : public ImmExprNode {
147147
struct UIntImmNode : public ImmExprNode {
148148
UIntImmNode(unsigned long long val) : ImmExprNode(UInt(sizeof(long long)*8)), val(val) {}
149149

150-
void accept(ExprVisitorStrict* v) const {
150+
void accept(IndexExprVisitorStrict* v) const {
151151
v->visit(this);
152152
}
153153

@@ -157,7 +157,7 @@ struct UIntImmNode : public ImmExprNode {
157157
struct ComplexImmNode : public ImmExprNode {
158158
ComplexImmNode(std::complex<double> val) : ImmExprNode(Complex128), val(val){}
159159

160-
void accept(ExprVisitorStrict* v) const {
160+
void accept(IndexExprVisitorStrict* v) const {
161161
v->visit(this);
162162
}
163163

@@ -167,7 +167,7 @@ struct ComplexImmNode : public ImmExprNode {
167167
struct FloatImmNode : public ImmExprNode {
168168
FloatImmNode(double val) : ImmExprNode(Float()), val(val) {}
169169

170-
void accept(ExprVisitorStrict* v) const {
170+
void accept(IndexExprVisitorStrict* v) const {
171171
v->visit(this);
172172
}
173173

@@ -180,7 +180,7 @@ struct AssignmentNode : public TensorExprNode {
180180
AssignmentNode(const Access& lhs, const IndexExpr& rhs, const IndexExpr& op)
181181
: lhs(lhs), rhs(rhs), op(op) {}
182182

183-
void accept(ExprVisitorStrict* v) const {
183+
void accept(IndexNotationVisitorStrict* v) const {
184184
v->visit(this);
185185
}
186186

@@ -193,7 +193,7 @@ struct ForallNode : public TensorExprNode {
193193
ForallNode(IndexVar indexVar, TensorExpr expr)
194194
: indexVar(indexVar), expr(expr) {}
195195

196-
void accept(ExprVisitorStrict* v) const {
196+
void accept(IndexNotationVisitorStrict* v) const {
197197
v->visit(this);
198198
}
199199

@@ -205,7 +205,7 @@ struct WhereNode : public TensorExprNode {
205205
WhereNode(TensorExpr consumer, TensorExpr producer)
206206
: consumer(consumer), producer(producer) {}
207207

208-
void accept(ExprVisitorStrict* v) const {
208+
void accept(IndexNotationVisitorStrict* v) const {
209209
v->visit(this);
210210
}
211211

include/taco/index_notation/expr_printer.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
namespace taco {
88

9-
class ExprPrinter : public ExprVisitorStrict {
9+
class IndexNotationPrinter : public IndexNotationVisitorStrict {
1010
public:
11-
ExprPrinter(std::ostream& os);
11+
IndexNotationPrinter(std::ostream& os);
1212

1313
void print(const IndexExpr& expr);
1414
void print(const TensorExpr& expr);
1515

16-
using ExprVisitorStrict::visit;
16+
using IndexExprVisitorStrict::visit;
1717

1818
// Scalar Expressions
1919
void visit(const AccessNode*);

include/taco/index_notation/expr_rewriter.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ struct ReductionNode;
2323

2424
struct AssignmentNode;
2525

26-
class ExprRewriterStrict : public ExprVisitorStrict {
26+
class ExprRewriterStrict : public IndexExprVisitorStrict {
2727
public:
2828
virtual ~ExprRewriterStrict() {}
2929

@@ -32,7 +32,7 @@ class ExprRewriterStrict : public ExprVisitorStrict {
3232
TensorExpr rewrite(TensorExpr);
3333

3434
protected:
35-
using ExprVisitorStrict::visit;
35+
using IndexExprVisitorStrict::visit;
3636

3737
/// assign to expr in visit methods to replace the visited expr
3838
IndexExpr expr;

include/taco/index_notation/expr_visitor.h

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,11 @@ struct WhereNode;
3131

3232
/// Visit the nodes in an expression. This visitor provides some type safety
3333
/// by requing all visit methods to be overridden.
34-
class ExprVisitorStrict {
34+
class IndexExprVisitorStrict {
3535
public:
36-
virtual ~ExprVisitorStrict();
36+
virtual ~IndexExprVisitorStrict();
3737

3838
void visit(const IndexExpr&);
39-
void visit(const TensorExpr&);
4039

4140
// Scalar Index Expressions
4241
virtual void visit(const AccessNode*) = 0;
@@ -51,22 +50,30 @@ class ExprVisitorStrict {
5150
virtual void visit(const ComplexImmNode*) = 0;
5251
virtual void visit(const UIntImmNode*) = 0;
5352
virtual void visit(const ReductionNode*) = 0;
54-
55-
// Tensor Expressions
56-
virtual void visit(const AssignmentNode*) {}
57-
virtual void visit(const ForallNode*) {}
58-
virtual void visit(const WhereNode*) {}
5953
};
6054

55+
/// Visit nodes in index notation
56+
class IndexNotationVisitorStrict : public IndexExprVisitorStrict {
57+
public:
58+
virtual ~IndexNotationVisitorStrict();
59+
60+
void visit(const TensorExpr&);
61+
62+
using IndexExprVisitorStrict::visit;
63+
64+
virtual void visit(const AssignmentNode*) = 0;
65+
virtual void visit(const ForallNode*) = 0;
66+
virtual void visit(const WhereNode*) = 0;
67+
};
6168

6269
/// Visit nodes in an expression.
63-
class ExprVisitor : public ExprVisitorStrict {
70+
class IndexNotationVisitor : public IndexExprVisitorStrict {
6471
public:
65-
virtual ~ExprVisitor();
72+
virtual ~IndexNotationVisitor();
6673

67-
using ExprVisitorStrict::visit;
74+
using IndexExprVisitorStrict::visit;
6875

69-
// Scalar Index Expressions
76+
// Index Expressions
7077
virtual void visit(const AccessNode* op);
7178
virtual void visit(const NegNode* op);
7279
virtual void visit(const SqrtNode* op);
@@ -109,10 +116,10 @@ void visit(const Rule* op) { \
109116
Rule##CtxFunc(op, this); \
110117
return; \
111118
} \
112-
ExprVisitor::visit(op); \
119+
IndexNotationVisitor::visit(op); \
113120
}
114121

115-
class Matcher : public ExprVisitor {
122+
class Matcher : public IndexNotationVisitor {
116123
public:
117124
template <class IndexExpr>
118125
void match(IndexExpr indexExpr) {
@@ -132,7 +139,7 @@ class Matcher : public ExprVisitor {
132139
unpack(rest...);
133140
}
134141

135-
using ExprVisitor::visit;
142+
using IndexNotationVisitor::visit;
136143
RULE(AccessNode)
137144
RULE(NegNode)
138145
RULE(SqrtNode)

include/taco/index_notation/index_notation.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class IndexExpr : public util::IntrusivePtr<const ExprNode> {
112112
const Schedule& getSchedule() const;
113113

114114
/// Visit the index expression's sub-expressions.
115-
void accept(ExprVisitorStrict *) const;
115+
void accept(IndexExprVisitorStrict *) const;
116116

117117
/// Print the index expression.
118118
friend std::ostream& operator<<(std::ostream&, const IndexExpr&);
@@ -213,7 +213,7 @@ class TensorExpr : public util::IntrusivePtr<const TensorExprNode> {
213213
TensorExpr(const TensorExprNode* n);
214214

215215
/// Visit the tensor expression
216-
void accept(ExprVisitorStrict *) const;
216+
void accept(IndexNotationVisitorStrict *) const;
217217
};
218218

219219
std::ostream& operator<<(std::ostream&, const TensorExpr&);

src/index_notation/expr_nodes.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ using namespace std;
88
namespace taco {
99

1010
vector<TensorVar> getOperands(const IndexExpr& expr) {
11-
struct GetOperands : public ExprVisitor {
12-
using ExprVisitor::visit;
11+
struct GetOperands : public IndexNotationVisitor {
12+
using IndexNotationVisitor::visit;
1313
set<TensorVar> inserted;
1414
vector<TensorVar> operands;
1515
void visit(const AccessNode* node) {

0 commit comments

Comments
 (0)