-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcommon.hpp
100 lines (94 loc) · 3 KB
/
common.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#pragma once
#include <cassert>
#include <utility>
#include <string>
#include <vector>
#include <memory>
#include <fstream>
/*
This global-variable style assumes that exactly one translation unit includes this header.
candidate_words contains the words that could be the final answer.
query_words contains the words that may be used as queries (guesses).
LEN is the length of every word; K is the alphabet size and START is the first letter of the alphabet.
*/
// Word length shared by every vocabulary entry.
constexpr int LEN = 5;
// Number of distinct letters; letters are the K characters starting at START.
constexpr int K = 26;
constexpr char START = 'a';

// Word lists, filled by load_vocabulary().
std::vector<std::string> candidate_words;
std::vector<std::string> query_words;

// Index types: word_id indexes into a word list, response_t holds an encoded response.
using word_id = unsigned int;
using response_t = unsigned int;

// Number of distinct encoded responses: 3^5.
constexpr response_t RESPONSE_RANGE = 243;

// response_table[query word id][candidate word id] -> encoded response.
std::vector<std::vector<response_t>> response_table;
// candidate_query_id[candidate word id] -> id of the same word in query_words.
std::vector<word_id> candidate_query_id;
// Folds the LEN entries of `result` into one number using radix 3;
// result[0] ends up as the most significant digit.
static response_t encode_query_response(response_t result[]) {
    response_t code = 0;
    for (const response_t *digit = result; digit != result + LEN; ++digit) {
        code = code * 3 + *digit;
    }
    return code;
}
// Computes the encoded response for guessing `query` when the hidden word is `answer`.
// Per position the mark is:
//   0 - query and answer have the same letter at this position,
//   1 - the letter differs here but matches a not-yet-claimed unmatched letter of answer,
//   2 - otherwise.
// Both words must have at least LEN letters, all within [START, START + K).
static response_t get_query_response(const std::string &query, const std::string &answer) {
    response_t marks[LEN];
    int unmatched[K] = {0};

    // Pass 1: mark exact matches and tally the answer letters that were not matched in place.
    for (int pos = 0; pos < LEN; ++pos) {
        if (query[pos] != answer[pos]) {
            ++unmatched[answer[pos] - START];
            continue;
        }
        marks[pos] = 0;
    }

    // Pass 2: left to right, each remaining query letter claims one tallied answer letter if available.
    for (int pos = 0; pos < LEN; ++pos) {
        if (query[pos] == answer[pos]) {
            continue;
        }
        int &remaining = unmatched[query[pos] - START];
        if (remaining > 0) {
            --remaining;
            marks[pos] = 1;
        } else {
            marks[pos] = 2;
        }
    }

    return encode_query_response(marks);
}
static void load_vocabulary(const std::string &candidate_vocab_filename, const std::string &query_vocab_filename) {
std::ifstream f1{candidate_vocab_filename};
std::string word;
while (f1 >> word) {
assert(word.length() == LEN);
for (char c : word) {
assert(c >= START && c < START + K);
}
candidate_words.push_back(word);
}
std::ifstream f2{query_vocab_filename};
while (f2 >> word) {
assert(word.length() == LEN);
for (char c : word) {
assert(c >= START && c < START + K);
}
query_words.push_back(word);
}
size_t qn = query_words.size(), cn = candidate_words.size();
response_table.resize(qn);
candidate_query_id.resize(cn);
for (word_id q = 0; q < qn; ++q) {
response_table[q].reserve(cn);
for (word_id c = 0; c < cn; ++c) {
if (query_words[q] == candidate_words[c]) {
candidate_query_id[c] = q;
}
response_table[q].push_back(get_query_response(query_words[q], candidate_words[c]));
}
}
}
struct ResponseNode;
// Tree node that owns at most one ResponseNode child and records a query word id.
struct GuessNode {
    int depth; // set from the constructor argument
    std::unique_ptr<ResponseNode> child; // null until assigned
    // Zero-initialized so that reading it before a query has been stored is well-defined
    // (the constructor does not set it).
    word_id best_query = 0;
    GuessNode(int d): depth(d){}
};
// Tree node tied to one query word; owns one GuessNode child per stored response value.
struct ResponseNode {
    int depth;          // set from the first constructor argument
    word_id query_word; // set from the second constructor argument
    // (response, owned child) pairs; empty after construction.
    std::vector<std::pair<response_t, std::unique_ptr<GuessNode>>> children;

    ResponseNode(int node_depth, word_id query)
        : depth{node_depth}, query_word{query}, children{} {}
};