// Agent.cpp
#include <iostream>
#include <vector>
#include <cstdlib>   // rand, srand
#include <ctime>     // time
#include <stdexcept> // invalid_argument, out_of_range
#include <string>
#include "NeuralNetwork/NeuralNet.hpp"
using namespace std;
// Fixed-capacity replay buffer. Each entry stores one transition
// (state, action, reward, next_state, done); sampling is uniform at random.
class ReplayMemory {
private:
    vector<vector<vector<double>>> mem;
    int capacity = 0;
public:
    // The most recently sampled transition, filled in by random().
    vector<double> current_state;
    int action;
    int reward;
    vector<double> next_state;
    bool is_done;
    ReplayMemory () {}
    ReplayMemory (int capacity) {
        this->capacity = capacity;
        srand(time(NULL));
    }
    // Append a transition, evicting the oldest entry once the buffer is full.
    // Scalars are wrapped in single-element vectors so every field fits the
    // same nested-vector storage.
    void store (vector<double> current_state,
                int action,
                int reward,
                vector<double> next_state,
                bool is_done) {
        if (!this->mem.empty() && (int)this->mem.size() >= capacity) {
            this->mem.erase(this->mem.begin());
        }
        vector<vector<double>> append_mem;
        append_mem.push_back(current_state);
        append_mem.push_back({(double)action});
        append_mem.push_back({(double)reward});
        append_mem.push_back(next_state);
        append_mem.push_back({(double)is_done});
        this->mem.push_back(append_mem);
    }
    // Sample one stored transition uniformly at random into the public fields.
    void random () {
        if (this->mem.empty()) throw out_of_range("Replay memory is empty");
        int random_num = rand() % this->mem.size();
        vector<vector<double>> mem_read = this->mem[random_num];
        this->current_state = mem_read[0];
        this->action = (int)mem_read[1][0];
        this->reward = (int)mem_read[2][0];
        this->next_state = mem_read[3];
        this->is_done = (bool)mem_read[4][0];
    }
};
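// A minimal sketch of exercising the buffer on its own; the 4-element states
// and values below are illustrative, not taken from the actual environment.
//
//   ReplayMemory demo_mem(1000);
//   demo_mem.store({0.1, 0.2, 0.3, 0.4}, /*action=*/1, /*reward=*/0,
//                  {0.2, 0.3, 0.4, 0.5}, /*is_done=*/false);
//   demo_mem.random();          // fills the public sampled-transition fields
//   int sampled_action = demo_mem.action;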
// DQN-style agent: an online network (net) picks actions, a periodically
// synced copy (target_net) supplies bootstrap values, and training samples
// come from replay memory.
class Agent {
private:
    NeuralNet net = NeuralNet();
    NeuralNet target_net = NeuralNet();
    ReplayMemory mem = ReplayMemory();
    int frameReachProb = 0;   // frames over which exploration is annealed
    int batches = 0;          // transitions sampled per train() call
    int targetFreqUpdate = 0; // frames between target-network syncs
    bool is_testing = false;
public:
    vector<double> last_prediction;
    int frames = 0;
    // Testing constructor: load trained parameters from a file. No training
    // happens in this mode, so the learning rate is passed as a -1 placeholder.
    Agent (vector<int> layout, string filename) {
        this->net = NeuralNet(layout, -1);
        this->net.open_params(filename);
        this->target_net = this->net;
        this->is_testing = true;
    }
    // Training constructor.
    Agent (vector<int> layout, double lr, int mem_capacity, int frameReachProb, int targetFreqUpdate, int batches) {
        this->mem = ReplayMemory(mem_capacity);
        this->net = NeuralNet(layout, lr);
        this->target_net = net;
        this->frameReachProb = frameReachProb;
        this->targetFreqUpdate = targetFreqUpdate;
        this->batches = batches;
        srand(time(NULL));
    }
    // Index of the largest element (greedy action selection).
    int argmax (vector<double> array) {
        int index = 0;
        double max_value = array[0];
        for (size_t i = 1; i < array.size(); i++) {
            if (array[i] > max_value) {
                max_value = array[i];
                index = (int)i;
            }
        }
        return index;
    }
    // Epsilon-greedy action selection. During training, the exploration
    // probability decays linearly from 1.0 to a 0.05 floor over the first
    // frameReachProb frames; in testing mode the policy is always greedy.
    int action (vector<double> input) {
        if (!is_testing) {
            this->frames++;
            double probability;
            if (frames <= frameReachProb) {
                probability = (-0.95 / double(frameReachProb)) * frames + 1;
            } else {
                probability = 0.05;
            }
            bool isRandom = (rand() % 100) < (probability * 100);
            int action;
            if (isRandom) {
                action = rand() % 3; // the action space is hardcoded to 3 actions
                last_prediction = vector<double>({-1, -1, -1}); // sentinel: no prediction made
            } else {
                last_prediction = this->net.predict(input);
                action = argmax(last_prediction);
            }
            return action;
        } else {
            last_prediction = this->net.predict(input);
            return argmax(last_prediction);
        }
    }
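    // For example, with an assumed frameReachProb of 50000 (not a value taken
    // from this file), epsilon starts near 1.0, is
    // (-0.95 / 50000) * 25000 + 1 = 0.525 at frame 25000, and sits at the
    // 0.05 floor from frame 50000 onward.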
    // Record one transition in replay memory (training mode only).
    void store_mem (vector<double> current_state, int action, int reward, vector<double> next_state, bool is_done) {
        if (is_testing) throw invalid_argument("Cannot use this function while testing");
        mem.store(current_state, action, reward, next_state, is_done);
    }
    // Largest element of a vector (used for the max over target-network Q-values).
    double max (vector<double> array) {
        double max_val = array[0];
        for (size_t i = 1; i < array.size(); i++) {
            if (array[i] > max_val) {
                max_val = array[i];
            }
        }
        return max_val;
    }
    // One training step: sample `batches` transitions and regress the online
    // network toward the Q-learning target
    //   y = r                                      if the episode ended
    //   y = r + gamma * max_a' Q_target(s', a')    otherwise (gamma = 0.99)
    void train () {
        if (is_testing) throw invalid_argument("Cannot use this function while testing");
        // sample minibatch
        for (int i = 0; i < batches; i++) {
            mem.random();
            vector<double> current_state = mem.current_state;
            int action = mem.action;
            int reward = mem.reward;
            vector<double> next_state = mem.next_state;
            bool is_done = mem.is_done;
            // Terminal transitions have no bootstrap term.
            double y;
            if (is_done) {
                y = reward;
            } else {
                y = reward + (0.99 * max(this->target_net.predict(next_state)));
            }
            // Only the taken action's output is moved toward y; the other
            // outputs keep the network's own predictions, giving them zero error.
            vector<double> target = this->net.predict(current_state);
            target[action] = y;
            this->net.backprop(current_state, target);
        }
        // Periodically sync the target network and checkpoint the parameters.
        if (frames % this->targetFreqUpdate == 0) {
            this->target_net = this->net;
            this->net.save_params("model.txt"); // parameters are saved as plain text
        }
    }
};
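// A rough sketch of the intended training loop, left as a comment so this
// file still compiles as-is. The layout {4, 24, 3}, the hyperparameters, and
// the environment calls are all hypothetical; only Agent's own API is real.
// Note the 3 network outputs must match the 3 actions hardcoded in action().
//
//   int main () {
//       Agent agent({4, 24, 3}, // 4 inputs, one hidden layer, 3 Q-value outputs
//                   0.001,      // learning rate
//                   10000,      // replay-memory capacity
//                   50000,      // frames over which epsilon is annealed
//                   1000,       // frames between target-network syncs
//                   32);        // transitions sampled per train() call
//       // Per environment step:
//       //   int a = agent.action(state);
//       //   take the action, observe reward, next_state, done;
//       //   agent.store_mem(state, a, reward, next_state, done);
//       //   agent.train();
//   }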