-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathbsg_scatter_gather.rs
129 lines (114 loc) · 4.09 KB
/
bsg_scatter_gather.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
//! Given a bit vector, generates permutation vectors that perform concentration (fwd) and deconcentration (bkwd).
//!
//! ```text
//! bit fwd bkwd
//! pos. vec vec vec
//! 3 1 --\ 1 --- 2
//! \ /
//! 2 1 -\ -> 3 <-- -- 1
//! \ /
//! 1 0 --> 2 <--- 3 --> 1
//!
//! 0 1 -----> 0 <------ 0
//! ```
//!
//! For empty slots, we just pick an unused slot, possibly reusing the same empty slot multiple times. This allows
//! control logic to be unselected.
use shakeflow::*;
use shakeflow_std::*;
/// Concentration / deconcentration permutation vectors for one input bit mask.
///
/// Each field holds `VEC_SIZE` indices, each `Log2<VEC_SIZE>` bits wide. The field descriptions below
/// reflect how `gen_case_expr` in this file fills them, where "the k-th set bit" is counted from bit 0
/// upwards (k starting at 0).
#[derive(Debug, Clone, Signal)]
pub struct E<const VEC_SIZE: usize> {
/// Forward (concentrate) vector: entry `k` is the position of the k-th set bit; entries past the
/// last set bit all hold a spare (clear) bit position.
fwd: Array<Bits<Log2<U<VEC_SIZE>>>, U<VEC_SIZE>>,
/// Entry `k` is `position - k` for the k-th set bit (i.e. the number of clear bits below it);
/// 0 for the padding entries past the last set bit.
fwd_datapath: Array<Bits<Log2<U<VEC_SIZE>>>, U<VEC_SIZE>>,
/// Backward (deconcentrate) vector: entry `i` is the rank of bit `i` among the set bits when bit
/// `i` is set, and `VEC_SIZE - 1` when it is clear.
bk: Array<Bits<Log2<U<VEC_SIZE>>>, U<VEC_SIZE>>,
/// Same as `bk` for set bits, but 0 (instead of `VEC_SIZE - 1`) for clear bits.
bk_datapath: Array<Bits<Log2<U<VEC_SIZE>>>, U<VEC_SIZE>>,
}
impl<const VEC_SIZE: usize> E<VEC_SIZE> {
pub fn new_expr() -> Expr<Self> {
EProj {
fwd: Expr::<Bits<Log2<U<VEC_SIZE>>>>::from(0).repeat(),
fwd_datapath: Expr::<Bits<Log2<U<VEC_SIZE>>>>::from(0).repeat(),
bk: Expr::<Bits<Log2<U<VEC_SIZE>>>>::from(0).repeat(),
bk_datapath: Expr::<Bits<Log2<U<VEC_SIZE>>>>::from(0).repeat(),
}
.into()
}
}
/// Input channel of the module built by `m`: a `VEC_SIZE`-bit mask.
pub type IC<const VEC_SIZE: usize> = UniChannel<Bits<U<VEC_SIZE>>>;
/// Output channel of the module built by `m`: the [`E`] permutation vectors for the input mask.
pub type EC<const VEC_SIZE: usize> = UniChannel<E<VEC_SIZE>>;
/// Converts an integer to exactly `pad` bits, most-significant bit first.
///
/// For example, `int_to_bit_list(1, 3)` is `[false, false, true]` and `int_to_bit_list(6, 3)` is
/// `[true, true, false]`. Zero fits in any width, including `pad == 0` (which yields an empty list).
///
/// # Panics
///
/// If `a` needs more than `pad` bits to represent.
#[inline]
fn int_to_bit_list(a: usize, pad: usize) -> Vec<bool> {
    // Minimum number of bits needed to represent `a` (0 for `a == 0`).
    let width = (usize::BITS - a.leading_zeros()) as usize;
    assert!(width <= pad, "int_to_bit_list: {a} does not fit in {pad} bits");
    // Bits at or above `width` are zero; checking `k < width` first also keeps the shift in range
    // when `pad` exceeds `usize::BITS`.
    (0..pad).rev().map(|k| k < width && (a >> k) & 1 == 1).collect()
}
/// Turns a list of slot indices into a constant array expression of `Log2<VEC_SIZE>`-bit values.
///
/// Corresponds to `print_case_line`. Each index is rendered MSB-first into `clog2(VEC_SIZE)` bools by
/// `int_to_bit_list`, extended with `false` at the tail up to `VEC_SIZE` bools, turned into a
/// `Bits<VEC_SIZE>` constant, and finally `resize`d to the index width.
///
/// NOTE(review): correctness depends on how shakeflow maps a `[bool; N]` onto bit positions and on which
/// end `resize` keeps when narrowing — confirm against shakeflow before changing the bit handling.
///
/// # Panics
///
/// If `result` does not contain exactly `VEC_SIZE` entries.
#[inline]
fn result_to_case_rhs<const VEC_SIZE: usize>(result: Vec<usize>) -> Expr<Array<Bits<Log2<U<VEC_SIZE>>>, U<VEC_SIZE>>> {
    let index_to_expr = |index: usize| -> Expr<Bits<Log2<U<VEC_SIZE>>>> {
        let mut bools = int_to_bit_list(index, clog2(VEC_SIZE));
        bools.resize(VEC_SIZE, false);
        let fixed: [bool; VEC_SIZE] = bools.try_into().unwrap();
        Expr::<Bits<U<VEC_SIZE>>>::from(fixed).resize()
    };

    let entries: Vec<Expr<Bits<Log2<U<VEC_SIZE>>>>> = result.into_iter().map(index_to_expr).collect();
    let entries: [Expr<Bits<Log2<U<VEC_SIZE>>>>; VEC_SIZE] = entries.try_into().unwrap();
    entries.into()
}
/// Builds the case-table entry (all four permutation vectors of [`E`]) for the bit mask `a`.
///
/// Corresponds to `gen_{fwd|back}_vec_line_helper`. The mask is walked from bit 0 upwards:
/// - a set bit at position `i` that is the `k`-th set bit seen gets `fwd[k] = i`,
///   `fwd_datapath[k] = i - k`, and `bk[i] = bk_datapath[i] = k`;
/// - a clear bit at position `i` gets `bk[i] = VEC_SIZE - 1` and `bk_datapath[i] = 0`.
///
/// The unused tail of `fwd` is filled with the highest clear bit position, and the tail of
/// `fwd_datapath` with 0. (A mask with no clear bits has no tail.)
#[inline]
fn gen_case_expr<const VEC_SIZE: usize>(a: usize) -> Expr<E<VEC_SIZE>> {
    // `int_to_bit_list` is MSB-first; reverse it so that the index equals the bit position.
    let lsb_first: Vec<bool> = int_to_bit_list(a, VEC_SIZE).into_iter().rev().collect();

    let mut fwd = Vec::with_capacity(VEC_SIZE);
    let mut fwd_datapath = Vec::with_capacity(VEC_SIZE);
    let mut bk = Vec::with_capacity(VEC_SIZE);
    let mut bk_datapath = Vec::with_capacity(VEC_SIZE);

    // Number of set bits seen so far, i.e. the concentrated slot the next set bit lands in.
    let mut rank = 0;
    // Latest clear position seen. Only read when padding `fwd`, which only happens if a clear bit
    // exists, so the initial value is never observed.
    let mut last_clear = 0;

    for (pos, &set) in lsb_first.iter().enumerate() {
        if !set {
            last_clear = pos;
            bk.push(VEC_SIZE - 1);
            bk_datapath.push(0);
            continue;
        }
        fwd.push(pos);
        fwd_datapath.push(pos - rank);
        bk.push(rank);
        bk_datapath.push(rank);
        rank += 1;
    }

    // Pad the forward vectors out to VEC_SIZE entries.
    fwd.resize(VEC_SIZE, last_clear);
    fwd_datapath.resize(VEC_SIZE, 0);

    EProj {
        fwd: result_to_case_rhs(fwd),
        fwd_datapath: result_to_case_rhs(fwd_datapath),
        bk: result_to_case_rhs(bk),
        bk_datapath: result_to_case_rhs(bk_datapath),
    }
    .into()
}
/// Builds the `bsg_scatter_gather` module.
///
/// `Some("i")` / `Some("o")` are passed to `composite` as the input/output names (presumably port
/// prefixes — confirm in shakeflow). The output is a lookup table over the input mask with one case
/// per possible mask value (`2^VEC_SIZE` entries, each produced by `gen_case_expr`) and `x` as the
/// default arm. The `fsm_map` state is `()` and is passed through unchanged.
pub fn m<const VEC_SIZE: usize>() -> Module<IC<VEC_SIZE>, EC<VEC_SIZE>> {
    composite::<IC<VEC_SIZE>, EC<VEC_SIZE>, _>("bsg_scatter_gather", Some("i"), Some("o"), |inp, ctx| {
        inp.fsm_map::<(), _, _>(ctx, None, ().into(), |mask, unit_state| {
            let num_masks: usize = 1 << VEC_SIZE;
            let mut table = Vec::with_capacity(num_masks);
            for pattern in 0..num_masks {
                table.push((Expr::<Bits<U<VEC_SIZE>>>::from(pattern), gen_case_expr(pattern)));
            }
            (mask.case(table, Some(Expr::x())), unit_state)
        })
    })
    .build()
}