-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathselect_from_weighted.rs
More file actions
199 lines (170 loc) · 6.99 KB
/
select_from_weighted.rs
File metadata and controls
199 lines (170 loc) · 6.99 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
use rand::distributions::uniform::SampleUniform;
use crate::{int_in_range, integers::Uint};
/// Selects one element from a given weighted list.
///
/// In contrast to [`pick`] this does not move the selected element out of the input list
/// but requires elements to be `Clone`able. This is because only one element is needed.
/// It could be implemented differently though.
///
/// Elements are selected proportionally to their weights: an element with weight `w`
/// is picked with probability `w / total_weight`, where `total_weight` is the sum of
/// all weights in `list`.
///
/// The list must not be empty. Each element must have a non-zero weight.
/// The total weight must fit into the weight type `W` (e.g. it must not exceed
/// `u32::MAX` when using `u32` weights).
///
/// ## Errors
///
/// Returns an error message if
///
/// - `list` is empty,
/// - any element has a weight of zero, or
/// - the sum of all weights overflows `W`.
///
/// ## Examples
///
/// Pick 1 hat out of 3 hats with different rarity:
///
/// ```
/// use nois::{randomness_from_str, select_from_weighted};
///
/// let randomness = randomness_from_str("9e8e26615f51552aa3b18b6f0bcf0dae5afbe30321e8d7ea7fa51ebeb1d8fe62").unwrap();
///
/// let list = vec![
///     ("green hat", 40u32),
///     ("viking helmet", 55u32),
///     ("rare golden crown", 5u32)
/// ];
///
/// let selected = select_from_weighted(randomness, &list).unwrap();
///
/// assert_eq!(selected, "viking helmet");
/// ```
pub fn select_from_weighted<T: Clone, W: Uint + SampleUniform>(
    randomness: [u8; 32],
    list: &[(T, W)],
) -> Result<T, String> {
    if list.is_empty() {
        return Err(String::from("List must not be empty"));
    }

    // Validate every weight and compute the total in a single pass. `checked_add` turns an
    // overflow of `W` into an error instead of a debug panic or a silent release wrap-around.
    let mut total_weight = W::ZERO;
    for (_, weight) in list {
        if *weight == W::ZERO {
            return Err(String::from("All element weights should be >= 1"));
        }
        total_weight = total_weight
            .checked_add(*weight)
            .ok_or_else(|| String::from("Total weight is greater than maximum value"))?;
    }
    debug_assert!(
        total_weight > W::ZERO,
        "we know we have a non-empty list of non-zero elements"
    );

    // Draw a point between `W::ONE` and `total_weight` and walk the cumulative weights:
    // each element "owns" a run of points as long as its weight, so the first element
    // whose cumulative weight reaches the point is the selected one.
    let r = int_in_range::<W>(randomness, W::ONE, total_weight);
    let mut weight_sum = W::ZERO;
    for (element, weight) in list {
        weight_sum += *weight;
        if r <= weight_sum {
            return Ok(element.clone());
        }
    }

    // After the loop `weight_sum == total_weight`, so as long as `int_in_range` never
    // returns a value above `total_weight` one of the iterations has returned already.
    unreachable!("No element selected")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::RANDOMNESS1;

    #[test]
    fn select_from_weighted_works() {
        let elements: Vec<(char, u32)> = vec![('a', 1), ('b', 5), ('c', 4)];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, 'c');

        // Element type is Clone but not Copy
        #[derive(PartialEq, Debug, Clone)]
        struct Color(String);

        let elements = vec![
            (Color("red".into()), 12u32),
            (Color("blue".to_string()), 15u32),
            (Color("green".to_string()), 8u32),
            (Color("orange".to_string()), 21u32),
            (Color("pink".to_string()), 11u32),
        ];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, Color("orange".to_string()));

        // Test for u128
        let elements = vec![
            (Color("red".into()), 12u128),
            (Color("blue".to_string()), 15u128),
            (Color("green".to_string()), 8u128),
            (Color("orange".to_string()), 21u128),
            (Color("pink".to_string()), 11u128),
        ];
        let picked = select_from_weighted(RANDOMNESS1, &elements).unwrap();
        assert_eq!(picked, Color("blue".to_string()));

        // Pick from slice
        let selection = &elements[0..3];
        let picked = select_from_weighted(RANDOMNESS1, selection).unwrap();
        assert_eq!(picked, Color("red".to_string()));
    }

    #[test]
    fn select_from_weighted_fails_on_empty_list() {
        let elements: Vec<(i32, u32)> = vec![];
        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();
        assert_eq!(err, "List must not be empty");
    }

    #[test]
    fn select_from_weighted_fails_on_element_weight_less_than_1() {
        // A single zero weight anywhere in the list rejects the whole list.
        let elements: Vec<(i32, u32)> = vec![(1, 5), (2, 4), (-3, 0)];
        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();
        assert_eq!(err, "All element weights should be >= 1");
    }

    #[test]
    fn select_from_weighted_fails_with_total_weight_too_high() {
        // u128::MAX + 1 overflows the u128 weight type.
        let elements: Vec<(i32, u128)> = vec![(1, u128::MAX), (2, 1)];
        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();
        // The message is generic because the weight type `W` is generic.
        assert_eq!(err, "Total weight is greater than maximum value");
    }

    #[test]
    fn select_from_weighted_fails_with_total_weight_too_high_u32() {
        // Overflow is detected in the weight type itself, not in some wider type.
        let elements: Vec<(i32, u32)> = vec![(1, u32::MAX), (2, 1)];
        let err = select_from_weighted(RANDOMNESS1, &elements).unwrap_err();
        assert_eq!(err, "Total weight is greater than maximum value");
    }

    #[test]
    fn select_from_weighted_passes_with_total_weight_not_too_high() {
        // The total exceeds u32::MAX but comfortably fits into u128.
        let elements: Vec<(i32, u128)> = vec![(1, u32::MAX.into()), (2, 1)];
        select_from_weighted(RANDOMNESS1, &elements).unwrap();
    }

    #[test]
    fn select_from_weighted_distribution_is_uniform() {
        // Performs a large number of selections using independent sub-randomness values and
        // checks that every element is picked with a frequency within 1% (relative) of
        // `weight / total_weight`.
        use crate::sub_randomness::sub_randomness;
        use std::collections::HashMap;

        const TEST_SAMPLE_SIZE: usize = 1_000_000;
        const ACCURACY: f32 = 0.01;

        // The weights do not have to sum to any particular value: the expected probability
        // of each element is its weight divided by the total weight (here 1000).
        let elements: Vec<(String, u32)> = vec![
            (String::from("a"), 100),
            (String::from("b"), 200),
            (String::from("c"), 30),
            (String::from("d"), 70),
            (String::from("e"), 600),
        ];
        let total_weight = elements.iter().map(|(_, weight)| *weight).sum::<u32>();
        println!("total weight: {}", total_weight);

        // Count selections directly instead of first storing every picked element.
        let mut histogram: HashMap<String, i32> = HashMap::new();
        for subrand in sub_randomness(RANDOMNESS1).take(TEST_SAMPLE_SIZE) {
            let picked = select_from_weighted(subrand, &elements).unwrap();
            *histogram.entry(picked).or_insert(0) += 1;
        }

        // Iterate over the input elements (not the histogram) so that an element that was
        // never selected fails the test instead of being silently skipped.
        for (bin, weight) in &elements {
            let count = histogram.get(bin).copied().unwrap_or(0);
            let probability = *weight as f32 / total_weight as f32;
            let estimated_count = TEST_SAMPLE_SIZE as f32 * probability;
            let estimation_min = (estimated_count * (1_f32 - ACCURACY)) as i32;
            let estimation_max = (estimated_count * (1_f32 + ACCURACY)) as i32;
            println!(
                "estimation {}, max: {}, min: {}",
                estimated_count, estimation_max, estimation_min
            );
            println!("{}: {}", bin, count);
            assert!(
                (estimation_min..=estimation_max).contains(&count),
                "{}: count {} outside of expected range [{}, {}]",
                bin,
                count,
                estimation_min,
                estimation_max
            );
        }
    }
}