|
42 | 42 | * identifies the pattern without applying fusion.
|
43 | 43 | */
|
44 | 44 | public class RewriteQuantizationFusedCompression extends HopRewriteRule {
|
45 |
| - @Override |
46 |
| - public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) { |
47 |
| - if(roots == null) |
48 |
| - return null; |
49 |
| - |
50 |
| - // traverse the HOP DAG |
51 |
| - HashMap<String, Hop> floors = new HashMap<>(); |
52 |
| - HashMap<String, Hop> compresses = new HashMap<>(); |
53 |
| - for(Hop h : roots) |
54 |
| - collectFloorCompressSequences(h, floors, compresses); |
55 |
| - |
56 |
| - Hop.resetVisitStatus(roots); |
57 |
| - |
58 |
| - // check compresses for compress-after-floor pattern |
59 |
| - for(Entry<String, Hop> e : compresses.entrySet()) { |
60 |
| - String inputname = e.getKey(); |
61 |
| - Hop compresshop = e.getValue(); |
62 |
| - |
63 |
| - if(floors.containsKey(inputname) // floors same name |
64 |
| - && ((floors.get(inputname).getBeginLine() < compresshop.getBeginLine()) || |
65 |
| - (floors.get(inputname).getEndLine() < compresshop.getEndLine()) || |
66 |
| - (floors.get(inputname).getBeginLine() == compresshop.getBeginLine() && |
67 |
| - floors.get(inputname).getEndLine() == compresshop.getBeginLine() && |
68 |
| - floors.get(inputname).getBeginColumn() < compresshop.getBeginColumn()))) { |
69 |
| - |
70 |
| - // retrieve the floor hop and inputs |
71 |
| - Hop floorhop = floors.get(inputname); |
72 |
| - Hop floorInput = floorhop.getInput().get(0); |
73 |
| - |
74 |
| - // check if the input of the floor operation is a matrix |
75 |
| - if(floorInput.getDataType() == DataType.MATRIX) { |
76 |
| - |
77 |
| - // Check if the input of the floor operation involves a multiplication operation |
78 |
| - if(floorInput instanceof BinaryOp && ((BinaryOp) floorInput).getOp() == OpOp2.MULT) { |
79 |
| - Hop initialMatrix = floorInput.getInput().get(0); |
80 |
| - Hop sf = floorInput.getInput().get(1); |
81 |
| - |
82 |
| - // create fused hop |
83 |
| - BinaryOp fusedhop = new BinaryOp("test", DataType.MATRIX, ValueType.FP64, |
84 |
| - OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf); |
85 |
| - |
86 |
| - // rewire compress consumers to fusedHop |
87 |
| - List<Hop> parents = new ArrayList<>(compresshop.getParent()); |
88 |
| - for(Hop p : parents) { |
89 |
| - HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop); |
90 |
| - } |
91 |
| - } |
92 |
| - } |
93 |
| - } |
94 |
| - } |
95 |
| - return roots; |
96 |
| - } |
97 |
| - |
98 |
| - @Override |
99 |
| - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { |
100 |
| - // do nothing, floor/compress do not occur in predicates |
101 |
| - return root; |
102 |
| - } |
103 |
| - |
104 |
| - private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) { |
105 |
| - if(hop.isVisited()) |
106 |
| - return; |
107 |
| - |
108 |
| - // process childs |
109 |
| - if(!hop.getInput().isEmpty()) |
110 |
| - for(Hop c : hop.getInput()) |
111 |
| - collectFloorCompressSequences(c, floors, compresses); |
112 |
| - |
113 |
| - // process current hop |
114 |
| - if(hop instanceof UnaryOp) { |
115 |
| - UnaryOp uop = (UnaryOp) hop; |
116 |
| - if(uop.getOp() == OpOp1.FLOOR) { |
117 |
| - floors.put(uop.getName(), uop); |
118 |
| - } |
119 |
| - else if(uop.getOp() == OpOp1.COMPRESS) { |
120 |
| - compresses.put(uop.getInput(0).getName(), uop); |
121 |
| - } |
122 |
| - } |
123 |
| - hop.setVisited(); |
124 |
| - } |
| 45 | + @Override |
| 46 | + public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) { |
| 47 | + if(roots == null) |
| 48 | + return null; |
| 49 | + |
| 50 | + // traverse the HOP DAG |
| 51 | + HashMap<String, Hop> floors = new HashMap<>(); |
| 52 | + HashMap<String, Hop> compresses = new HashMap<>(); |
| 53 | + for(Hop h : roots) |
| 54 | + collectFloorCompressSequences(h, floors, compresses); |
| 55 | + |
| 56 | + Hop.resetVisitStatus(roots); |
| 57 | + |
| 58 | + // check compresses for compress-after-floor pattern |
| 59 | + for(Entry<String, Hop> e : compresses.entrySet()) { |
| 60 | + String inputname = e.getKey(); |
| 61 | + Hop compresshop = e.getValue(); |
| 62 | + |
| 63 | + if(floors.containsKey(inputname) // floors same name |
| 64 | + && ((floors.get(inputname).getBeginLine() < compresshop.getBeginLine()) || |
| 65 | + (floors.get(inputname).getEndLine() < compresshop.getEndLine()) || |
| 66 | + (floors.get(inputname).getBeginLine() == compresshop.getBeginLine() && |
| 67 | + floors.get(inputname).getEndLine() == compresshop.getBeginLine() && |
| 68 | + floors.get(inputname).getBeginColumn() < compresshop.getBeginColumn()))) { |
| 69 | + |
| 70 | + // retrieve the floor hop and inputs |
| 71 | + Hop floorhop = floors.get(inputname); |
| 72 | + Hop floorInput = floorhop.getInput().get(0); |
| 73 | + |
| 74 | + // check if the input of the floor operation is a matrix |
| 75 | + if(floorInput.getDataType() == DataType.MATRIX) { |
| 76 | + |
| 77 | + // Check if the input of the floor operation involves a multiplication operation |
| 78 | + if(floorInput instanceof BinaryOp && ((BinaryOp) floorInput).getOp() == OpOp2.MULT) { |
| 79 | + Hop initialMatrix = floorInput.getInput().get(0); |
| 80 | + Hop sf = floorInput.getInput().get(1); |
| 81 | + |
| 82 | + // create fused hop |
| 83 | + BinaryOp fusedhop = new BinaryOp("test", DataType.MATRIX, ValueType.FP64, |
| 84 | + OpOp2.QUANTIZE_COMPRESS, initialMatrix, sf); |
| 85 | + |
| 86 | + // rewire compress consumers to fusedHop |
| 87 | + List<Hop> parents = new ArrayList<>(compresshop.getParent()); |
| 88 | + for(Hop p : parents) { |
| 89 | + HopRewriteUtils.replaceChildReference(p, compresshop, fusedhop); |
| 90 | + } |
| 91 | + } |
| 92 | + } |
| 93 | + } |
| 94 | + } |
| 95 | + return roots; |
| 96 | + } |
| 97 | + |
| 98 | + @Override |
| 99 | + public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { |
| 100 | + // do nothing, floor/compress do not occur in predicates |
| 101 | + return root; |
| 102 | + } |
| 103 | + |
| 104 | + private void collectFloorCompressSequences(Hop hop, HashMap<String, Hop> floors, HashMap<String, Hop> compresses) { |
| 105 | + if(hop.isVisited()) |
| 106 | + return; |
| 107 | + |
| 108 | + // process childs |
| 109 | + if(!hop.getInput().isEmpty()) |
| 110 | + for(Hop c : hop.getInput()) |
| 111 | + collectFloorCompressSequences(c, floors, compresses); |
| 112 | + |
| 113 | + // process current hop |
| 114 | + if(hop instanceof UnaryOp) { |
| 115 | + UnaryOp uop = (UnaryOp) hop; |
| 116 | + if(uop.getOp() == OpOp1.FLOOR) { |
| 117 | + floors.put(uop.getName(), uop); |
| 118 | + } |
| 119 | + else if(uop.getOp() == OpOp1.COMPRESS) { |
| 120 | + compresses.put(uop.getInput(0).getName(), uop); |
| 121 | + } |
| 122 | + } |
| 123 | + hop.setVisited(); |
| 124 | + } |
125 | 125 | }
|
0 commit comments